mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Merge remote-tracking branch 'ado_wai_ort/DmlDev' into user/dwayner/DML1.8forORT1.10
This commit is contained in:
commit
289b1bdc86
28 changed files with 2286 additions and 1071 deletions
|
|
@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
// Operator supports true 64-bit tensors directly, no strides needed.
|
||||
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
|
||||
bool prefer64BitTensorsDirectly = false;
|
||||
|
||||
// The operator supports emulation for uint64/int64 even if the hardware doesn't
|
||||
// support native uint64/int64 data types.
|
||||
bool support64BitTensorsViaEmulation = false;
|
||||
};
|
||||
|
||||
using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
|
||||
|
|
|
|||
|
|
@ -345,6 +345,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
bool support64BitTensorsViaEmulation,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
|
||||
uint32_t constantCpuInputCount) const noexcept
|
||||
{
|
||||
|
|
@ -472,6 +473,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
|
||||
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
|
||||
regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation;
|
||||
|
||||
// Only internal operators support usage in DML graphs
|
||||
if (supportsGraph)
|
||||
|
|
@ -546,7 +548,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
requiredConstantCpuInputs ||
|
||||
supportedWith64BitTensorsVia32BitStrides ||
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
|
||||
prefer64BitTensorsDirectly)
|
||||
prefer64BitTensorsDirectly ||
|
||||
support64BitTensorsViaEmulation)
|
||||
{
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
|
|||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
bool support64BitTensorsViaEmulation = false,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
|
||||
uint32_t constantCpuInputCount = 0) const noexcept override;
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
|
|||
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
|
||||
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
|
||||
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_FLOAT64;
|
||||
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
|
||||
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT64;
|
||||
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
|
||||
|
|
@ -119,7 +119,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
|
|||
uint32_t deviceTypeMask = 0u;
|
||||
|
||||
// Form the bitmask of all supported data types.
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
|
||||
{
|
||||
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
|
||||
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
|
||||
|
|
|
|||
|
|
@ -96,4 +96,24 @@ namespace Dml
|
|||
|
||||
return minimumImpliedSizeInBytes;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue)
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case DML_TENSOR_DATA_TYPE_UINT8: outputValue->UInt8 = clamp_cast<uint8_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_UINT16: outputValue->UInt16 = clamp_cast<uint16_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_UINT32: outputValue->UInt32 = clamp_cast<uint32_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_UINT64: outputValue->UInt64 = clamp_cast<uint64_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_INT8: outputValue->Int8 = clamp_cast<int8_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_INT16: outputValue->Int16 = clamp_cast<int16_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_INT32: outputValue->Int32 = clamp_cast<int32_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_INT64: outputValue->Int64 = clamp_cast<int64_t, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT16: outputValue->Float32 = clamp_cast<float, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT32: outputValue->Float32 = clamp_cast<float, T>(value); break;
|
||||
case DML_TENSOR_DATA_TYPE_FLOAT64: outputValue->Float64 = clamp_cast<double, T>(value); break;
|
||||
default: assert(false);
|
||||
}
|
||||
}
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 141;
|
||||
static constexpr auto ValueCount = 153;
|
||||
static constexpr size_t ActivationFunctionCount = 20;
|
||||
};
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
|
|||
template <>
|
||||
struct EnumTraits<DML_FEATURE_LEVEL>
|
||||
{
|
||||
static constexpr auto ValueCount = 4;
|
||||
static constexpr auto ValueCount = 8;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
|
@ -225,6 +225,24 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_COS_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -363,6 +381,18 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SQRT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATAN_YX;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -483,6 +513,12 @@ struct OperatorDescTraits<DML_PADDING_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_PADDING1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_VALUE_SCALE_2D_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -531,6 +567,18 @@ struct OperatorDescTraits<DML_BATCH_NORMALIZATION_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -543,6 +591,12 @@ struct OperatorDescTraits<DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_LP_NORMALIZATION_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -579,6 +633,12 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_NAN;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_NEGATE;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_ERF_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -717,6 +777,12 @@ struct OperatorDescTraits<DML_CUMULATIVE_SUMMATION_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_CUMULATIVE_PRODUCT_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_PRODUCT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -891,12 +957,42 @@ struct OperatorDescTraits<DML_ROI_ALIGN_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ROI_ALIGN1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_GATHER_ND1_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ROI_ALIGN_GRAD_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN_GRAD;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1071,6 +1167,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP>
|
|||
using DescType = DML_ELEMENT_WISE_CLIP_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP1>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COS>
|
||||
{
|
||||
|
|
@ -1209,6 +1323,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SQRT>
|
|||
using DescType = DML_ELEMENT_WISE_SQRT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATAN_YX>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SUBTRACT>
|
||||
{
|
||||
|
|
@ -1329,6 +1455,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING>
|
|||
using DescType = DML_PADDING_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING1>
|
||||
{
|
||||
using DescType = DML_PADDING1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_VALUE_SCALE_2D>
|
||||
{
|
||||
|
|
@ -1377,6 +1509,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION>
|
|||
using DescType = DML_BATCH_NORMALIZATION_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_GRAD>
|
||||
{
|
||||
using DescType = DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD>
|
||||
{
|
||||
using DescType = DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION>
|
||||
{
|
||||
|
|
@ -1389,6 +1533,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALI
|
|||
using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD>
|
||||
{
|
||||
using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_NORMALIZATION>
|
||||
{
|
||||
|
|
@ -1425,6 +1575,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_NAN>
|
|||
using DescType = DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_NEGATE>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ERF>
|
||||
{
|
||||
|
|
@ -1563,6 +1719,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION>
|
|||
using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_PRODUCT>
|
||||
{
|
||||
using DescType = DML_CUMULATIVE_PRODUCT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES>
|
||||
{
|
||||
|
|
@ -1737,12 +1899,42 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN>
|
|||
using DescType = DML_ROI_ALIGN_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN1>
|
||||
{
|
||||
using DescType = DML_ROI_ALIGN1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1>
|
||||
{
|
||||
using DescType = DML_GATHER_ND1_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR>
|
||||
{
|
||||
using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD>
|
||||
{
|
||||
using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN_GRAD>
|
||||
{
|
||||
using DescType = DML_ROI_ALIGN_GRAD_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING>
|
||||
{
|
||||
using DescType = DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
|
||||
{
|
||||
|
|
@ -1894,6 +2086,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CEIL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_COS:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_COS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIVIDE:
|
||||
|
|
@ -1940,6 +2138,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_SQRT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SQRT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_TAN:
|
||||
|
|
@ -1980,6 +2182,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_JOIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_PADDING:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_PADDING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_PADDING1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_PADDING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_VALUE_SCALE_2D:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_VALUE_SCALE_2D_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_UPSAMPLE_2D:
|
||||
|
|
@ -1996,10 +2200,16 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_TOP_K_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_LP_NORMALIZATION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_LP_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_RNN:
|
||||
|
|
@ -2012,6 +2222,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_IS_NAN:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_NEGATE:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_ERF:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_ERF_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_SINH:
|
||||
|
|
@ -2058,6 +2270,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_CUMULATIVE_SUMMATION:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_CUMULATIVE_PRODUCT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_PRODUCT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_GATHER_ELEMENTS:
|
||||
|
|
@ -2116,8 +2330,18 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ROI_ALIGN:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ROI_ALIGN1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_GATHER_ND1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ROI_ALIGN_GRAD:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ACTIVATION_CELU:
|
||||
|
|
@ -2180,6 +2404,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL";
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP";
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1";
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD";
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1";
|
||||
case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS";
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE";
|
||||
case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP";
|
||||
|
|
@ -2203,6 +2430,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP";
|
||||
case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT";
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE";
|
||||
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX";
|
||||
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT";
|
||||
case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD";
|
||||
|
|
@ -2223,6 +2452,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_SPLIT: return "DML_OPERATOR_SPLIT";
|
||||
case DML_OPERATOR_JOIN: return "DML_OPERATOR_JOIN";
|
||||
case DML_OPERATOR_PADDING: return "DML_OPERATOR_PADDING";
|
||||
case DML_OPERATOR_PADDING1: return "DML_OPERATOR_PADDING1";
|
||||
case DML_OPERATOR_VALUE_SCALE_2D: return "DML_OPERATOR_VALUE_SCALE_2D";
|
||||
case DML_OPERATOR_UPSAMPLE_2D: return "DML_OPERATOR_UPSAMPLE_2D";
|
||||
case DML_OPERATOR_GATHER: return "DML_OPERATOR_GATHER";
|
||||
|
|
@ -2231,14 +2461,18 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE";
|
||||
case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K";
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION";
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD";
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD";
|
||||
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION";
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION";
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD";
|
||||
case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION";
|
||||
case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN";
|
||||
case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM";
|
||||
case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU";
|
||||
case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN";
|
||||
case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE";
|
||||
case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF";
|
||||
case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH";
|
||||
case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH";
|
||||
|
|
@ -2262,6 +2496,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT";
|
||||
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE";
|
||||
case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION";
|
||||
case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT";
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES";
|
||||
case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
|
||||
case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
|
||||
|
|
@ -2291,7 +2526,12 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
|
|||
case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD";
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER";
|
||||
case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN";
|
||||
case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1";
|
||||
case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1";
|
||||
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR";
|
||||
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD";
|
||||
case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD";
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING";
|
||||
default:
|
||||
assert(false);
|
||||
return "<unknown>";
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ enum DML_SCHEMA_FIELD_TYPE
|
|||
DML_SCHEMA_FIELD_TYPE_SCALE_BIAS,
|
||||
DML_SCHEMA_FIELD_TYPE_SIZE_2D,
|
||||
DML_SCHEMA_FIELD_TYPE_SCALAR_UNION,
|
||||
DML_SCHEMA_FIELD_TYPE_BOOL,
|
||||
};
|
||||
|
||||
enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS
|
||||
|
|
@ -170,6 +171,56 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA {
|
|||
DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_CLIP1",
|
||||
DML_OPERATOR_ELEMENT_WISE_CLIP1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
6,
|
||||
DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Max", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD",
|
||||
DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
5,
|
||||
DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1",
|
||||
DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
6,
|
||||
DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -493,6 +544,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA {
|
|||
DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE",
|
||||
DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_ATAN_YX",
|
||||
DML_OPERATOR_ELEMENT_WISE_ATAN_YX,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
3,
|
||||
DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
|
|
@ -828,6 +907,25 @@ constexpr DML_OPERATOR_SCHEMA DML_PADDING_OPERATOR_SCHEMA {
|
|||
DML_PADDING_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_PADDING1_OPERATOR_SCHEMA_FIELDS[8] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingMode", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingValueDataType", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "PaddingValue", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_PADDING1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_PADDING1",
|
||||
DML_OPERATOR_PADDING1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
8,
|
||||
DML_PADDING1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -954,6 +1052,46 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA {
|
|||
DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_BATCH_NORMALIZATION_GRAD",
|
||||
DML_OPERATOR_BATCH_NORMALIZATION_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
9,
|
||||
DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD",
|
||||
DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
9,
|
||||
DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[8] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true },
|
||||
|
|
@ -991,6 +1129,25 @@ constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA {
|
|||
DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[8] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_BOOL, "CrossChannel", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LocalSize", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD",
|
||||
DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
8,
|
||||
DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -1106,6 +1263,19 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA {
|
|||
DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS[2] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_NEGATE",
|
||||
DML_OPERATOR_ELEMENT_WISE_NEGATE,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
2,
|
||||
DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
@ -1426,8 +1596,8 @@ constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS[5] {
|
|||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA {
|
||||
|
|
@ -1438,6 +1608,22 @@ constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA {
|
|||
DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS[5] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveProduct", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_CUMULATIVE_PRODUCT",
|
||||
DML_OPERATOR_CUMULATIVE_PRODUCT,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
5,
|
||||
DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false },
|
||||
|
|
@ -1448,7 +1634,7 @@ constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] {
|
|||
constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_REVERSE_SUBSEQUENCES",
|
||||
DML_OPERATOR_REVERSE_SUBSEQUENCES,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
|
@ -1809,17 +1995,23 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA {
|
|||
DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_MAX_POOLING_GRAD",
|
||||
DML_OPERATOR_MAX_POOLING_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
3,
|
||||
9,
|
||||
DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
|
|
@ -1932,6 +2124,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA {
|
|||
DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS[14] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "InputPixelOffset", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutputPixelOffset", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutOfBoundsInputValue", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AlignRegionsToCorners", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN1_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ROI_ALIGN1",
|
||||
DML_OPERATOR_ROI_ALIGN1,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
14,
|
||||
DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
|
||||
|
|
@ -1949,6 +2166,87 @@ constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA {
|
|||
DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR",
|
||||
DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
4,
|
||||
DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD",
|
||||
DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
9,
|
||||
DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS[15] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputROIGradientTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "InputPixelOffset", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutputPixelOffset", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AlignRegionsToCorners", false },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_ROI_ALIGN_GRAD",
|
||||
DML_OPERATOR_ROI_ALIGN_GRAD,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
15,
|
||||
DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FusedAddTensor", true },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputMeanTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputVarianceTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true },
|
||||
};
|
||||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING",
|
||||
DML_OPERATOR_BATCH_NORMALIZATION_TRAINING,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
9,
|
||||
DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
|
||||
|
|
|
|||
|
|
@ -73,6 +73,38 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_OPERATOR
|
|||
OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<FLOAT>(desc.Max))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_SCALE_BIAS*>(desc.ScaleBias))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.MinMaxDataType))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Min))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Max))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Min))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<FLOAT>(desc.Max))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.MinMaxDataType))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Min))),
|
||||
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Max))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_COS_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -258,6 +290,22 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_SQRT_OPERATOR
|
|||
OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_SCALE_BIAS*>(desc.ScaleBias))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -473,6 +521,19 @@ inline std::vector<OperatorField> GetFields(const DML_PADDING_OPERATOR_DESC& des
|
|||
OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_PADDING1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.PaddingMode))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.PaddingValueDataType))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.PaddingValue))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_VALUE_SCALE_2D_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -551,6 +612,34 @@ inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_OPERAT
|
|||
OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_OPERATOR_DESC*>(desc.FusedActivation))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.MeanTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.VarianceTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputBiasGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.MeanTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.VarianceTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputBiasGradientTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -576,6 +665,19 @@ inline std::vector<OperatorField> GetFields(const DML_LOCAL_RESPONSE_NORMALIZATI
|
|||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.Bias))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<bool>(desc.CrossChannel))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.LocalSize))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta))),
|
||||
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.Bias))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_LP_NORMALIZATION_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -655,6 +757,13 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_IS_NAN_OPERAT
|
|||
OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_ERF_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -845,8 +954,18 @@ inline std::vector<OperatorField> GetFields(const DML_CUMULATIVE_SUMMATION_OPERA
|
|||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveSum))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
|
||||
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveSum))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_CUMULATIVE_PRODUCT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
|
||||
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
|
||||
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveProduct))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC& desc)
|
||||
|
|
@ -1094,6 +1213,12 @@ inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_
|
|||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
|
||||
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc)
|
||||
|
|
@ -1169,6 +1294,25 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& d
|
|||
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.InputPixelOffset))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.OutputPixelOffset))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.OutOfBoundsInputValue))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
|
||||
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<UINT>(desc.AlignRegionsToCorners))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1180,6 +1324,63 @@ inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC&
|
|||
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.BatchDimensionCount))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
|
||||
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
|
||||
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_GRAD_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputROIGradientTensor))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.InputPixelOffset))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<FLOAT>(desc.OutputPixelOffset))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
|
||||
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast<UINT>(desc.AlignRegionsToCorners))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.FusedAddTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputMeanTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputVarianceTensor))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
|
||||
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_OPERATOR_DESC*>(desc.FusedActivation))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1351,6 +1552,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_ELEMENT_WISE_ATAN: return DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_CEIL: return DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP: return DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP1: return DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_COS: return DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_EXP: return DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA;
|
||||
|
|
@ -1374,6 +1578,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_ELEMENT_WISE_RECIP: return DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_SIN: return DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_SQRT: return DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_TAN: return DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA;
|
||||
|
|
@ -1394,6 +1600,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_SPLIT: return DML_SPLIT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_JOIN: return DML_JOIN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_PADDING: return DML_PADDING_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_PADDING1: return DML_PADDING1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_VALUE_SCALE_2D: return DML_VALUE_SCALE_2D_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_UPSAMPLE_2D: return DML_UPSAMPLE_2D_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER: return DML_GATHER_OPERATOR_SCHEMA;
|
||||
|
|
@ -1402,14 +1609,18 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_TILE: return DML_TILE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_TOP_K: return DML_TOP_K_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION: return DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_LP_NORMALIZATION: return DML_LP_NORMALIZATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_RNN: return DML_RNN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_LSTM: return DML_LSTM_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GRU: return DML_GRU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_SIGN: return DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_NEGATE: return DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_ERF: return DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_SINH: return DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_COSH: return DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA;
|
||||
|
|
@ -1433,6 +1644,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_FILL_VALUE_CONSTANT: return DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_CUMULATIVE_PRODUCT: return DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA;
|
||||
|
|
@ -1462,7 +1674,12 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ROI_ALIGN1: return DML_ROI_ALIGN1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ROI_ALIGN_GRAD: return DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
|
||||
|
|
@ -1528,6 +1745,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_COS:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA,
|
||||
|
|
@ -1620,6 +1849,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_SQRT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA,
|
||||
|
|
@ -1700,6 +1937,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_PADDING_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_PADDING_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_PADDING1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_PADDING1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_PADDING1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_VALUE_SCALE_2D:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA,
|
||||
|
|
@ -1732,6 +1973,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA,
|
||||
|
|
@ -1740,6 +1989,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_LP_NORMALIZATION:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_LP_NORMALIZATION_OPERATOR_SCHEMA,
|
||||
|
|
@ -1764,6 +2017,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_NEGATE:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_ERF:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA,
|
||||
|
|
@ -1856,6 +2113,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_CUMULATIVE_SUMMATION_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_CUMULATIVE_PRODUCT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_CUMULATIVE_PRODUCT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA,
|
||||
|
|
@ -1972,10 +2233,30 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ROI_ALIGN_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ROI_ALIGN_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ROI_ALIGN1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ROI_ALIGN1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ROI_ALIGN1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_GATHER_ND1:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_GATHER_ND1_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_GATHER_ND1_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ROI_ALIGN_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ROI_ALIGN_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ACTIVATION_ELU:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
|
||||
|
|
|
|||
|
|
@ -139,29 +139,6 @@ namespace Dml::GraphDescBuilder
|
|||
&graphNodeInfo
|
||||
);
|
||||
|
||||
// Determine the number of valid inputs and outputs of this node. The graph currently supports opererators
|
||||
// with unused inputs and outputs only at the end of each list.
|
||||
uint32_t validOpInputCount = 0;
|
||||
uint32_t validOpOutputCount = 0;
|
||||
|
||||
for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i)
|
||||
{
|
||||
if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
assert(i - validOpInputCount == 0);
|
||||
++validOpInputCount;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i)
|
||||
{
|
||||
if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
assert(i - validOpOutputCount == 0);
|
||||
++validOpOutputCount;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t nodeIndex = gsl::narrow_cast<uint32_t>(graphNodes.size());
|
||||
AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy
|
||||
|
||||
|
|
@ -171,8 +148,13 @@ namespace Dml::GraphDescBuilder
|
|||
std::vector<DmlBufferTensorDesc*> outputTensorDescs = opDesc.GetOutputTensors();
|
||||
|
||||
// Set connections of the new node
|
||||
for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex)
|
||||
for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex)
|
||||
{
|
||||
if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex];
|
||||
|
||||
const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex];
|
||||
|
|
@ -224,8 +206,13 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
// Store the new node for lookup when downstream nodes consume it.
|
||||
|
||||
for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex)
|
||||
for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex)
|
||||
{
|
||||
if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits<uint32_t>::max())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex];
|
||||
const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex];
|
||||
if (arg->Exists())
|
||||
|
|
|
|||
|
|
@ -139,7 +139,11 @@ namespace Dml
|
|||
}
|
||||
};
|
||||
|
||||
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask)
|
||||
bool NodeArgSupportedInGraph(
|
||||
const onnxruntime::NodeArg* arg,
|
||||
bool supports64BitTensorsViaEmulation,
|
||||
uint32_t supportedDeviceDataTypeMask
|
||||
)
|
||||
{
|
||||
if (arg->Exists())
|
||||
{
|
||||
|
|
@ -154,8 +158,11 @@ namespace Dml
|
|||
|
||||
MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));
|
||||
|
||||
if (mlDataType == MLOperatorTensorDataType::UInt64 ||
|
||||
mlDataType == MLOperatorTensorDataType::Int64)
|
||||
// Do not include operators in the graph if tensor types are unsupported,
|
||||
// except cases that are always supported via emulation.
|
||||
if ((mlDataType == MLOperatorTensorDataType::UInt64 ||
|
||||
mlDataType == MLOperatorTensorDataType::Int64) &&
|
||||
!supports64BitTensorsViaEmulation)
|
||||
{
|
||||
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
|
||||
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
|
||||
|
|
@ -181,6 +188,7 @@ namespace Dml
|
|||
if (!isConstantCpuInput &&
|
||||
!NodeArgSupportedInGraph(
|
||||
node.InputDefs()[i],
|
||||
registration.support64BitTensorsViaEmulation,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
|
|
@ -192,6 +200,7 @@ namespace Dml
|
|||
{
|
||||
if (!NodeArgSupportedInGraph(
|
||||
arg,
|
||||
registration.support64BitTensorsViaEmulation,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
|
|
@ -234,6 +243,7 @@ namespace Dml
|
|||
ORT_THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
|
||||
|
||||
bool prefer64BitTensorsDirectly = false;
|
||||
bool support64BitTensorsViaEmulation = false;
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false;
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
|
||||
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
|
||||
|
|
@ -244,9 +254,10 @@ namespace Dml
|
|||
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
|
||||
// in this particular call, then the operator-specific flags do not matter as the caller has
|
||||
// disabled 64-bit support.
|
||||
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
|
||||
support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation;
|
||||
if (allow64BitInputThroughStrides)
|
||||
{
|
||||
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
|
||||
}
|
||||
|
|
@ -320,31 +331,44 @@ namespace Dml
|
|||
// operator, graph input or initializer, it's not safe to assume the input
|
||||
// can be represented with 32 bits.
|
||||
//
|
||||
bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
|
||||
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
|
||||
bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask);
|
||||
if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit)
|
||||
if (is64BitIntType)
|
||||
{
|
||||
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
|
||||
|
||||
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
|
||||
if (support64BitTensorsViaEmulation)
|
||||
{
|
||||
// Look up the input partition. If it's a graph input or initializer it will be missing
|
||||
// from the partition map.
|
||||
const std::string& argName = nodeArg.Name();
|
||||
// Consider it supported regardless of hardware support.
|
||||
isDataTypeSupported = true;
|
||||
}
|
||||
else if (prefer64BitTensorsDirectly && isDataTypeSupported)
|
||||
{
|
||||
// Operator supports native int64/uint64 tensors.
|
||||
}
|
||||
else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp)
|
||||
{
|
||||
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
|
||||
isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
|
||||
|
||||
// If input tensor's data comes from the output of a different execution provider,
|
||||
// consider it unsafe to apply fallback to.
|
||||
auto partitionIter = nodeNameToPartitionMap->find(argName);
|
||||
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
|
||||
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
// Look up the input partition. If it's a graph input or initializer it will be missing
|
||||
// from the partition map.
|
||||
const std::string& argName = nodeArg.Name();
|
||||
|
||||
// If input tensor's data comes from the output of a different execution provider,
|
||||
// consider it unsafe to apply fallback to.
|
||||
auto partitionIter = nodeNameToPartitionMap->find(argName);
|
||||
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reject node if the data type is unsupported by the device.
|
||||
if (!((1 << dmlElementType) & supportedDeviceDataTypeMask))
|
||||
if (!isDataTypeSupported)
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,13 @@ public:
|
|||
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
|
||||
Initialize(kernelCreationContext, inputIndices, outputIndices);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType();
|
||||
// operatorDesc.Value already zeroed.
|
||||
|
||||
// Read the tensor attribute for the output fill pattern.
|
||||
if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor))
|
||||
{
|
||||
|
|
@ -40,15 +47,14 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element.
|
||||
const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType());
|
||||
const std::byte* rawData = static_cast<const std::byte*>(valueTensor->GetData());
|
||||
valueBytes.assign(rawData, rawData + rawDataByteSize);
|
||||
|
||||
memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes)));
|
||||
}
|
||||
// Else valueBytes is empty, and the default fill pattern is 0.
|
||||
}
|
||||
|
||||
void Compute(const MLOperatorKernelContext& kernelContext) override
|
||||
{
|
||||
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
|
||||
ORT_THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes));
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -7,30 +7,45 @@ namespace Dml
|
|||
{
|
||||
|
||||
class DmlOperatorDynamicQuantizeLinear : public DmlOperator
|
||||
{
|
||||
{
|
||||
enum DmlInputTensors {
|
||||
IN_A,
|
||||
};
|
||||
|
||||
enum DmlOutputTensors {
|
||||
OUT_Y,
|
||||
OUT_Y_SCALE,
|
||||
OUT_Y_ZERO_POINT
|
||||
};
|
||||
|
||||
public:
|
||||
using Self = DmlOperatorDynamicQuantizeLinear;
|
||||
|
||||
DmlOperatorDynamicQuantizeLinear(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
#if 0 // TODO:NickFe - https://github.com/onnx/onnx/blob/master/docs/Operators.md#dynamicquantizelinear
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 3);
|
||||
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelCreationContext, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
|
||||
|
||||
m_outputTensorDescs[OUT_Y] = CreateTensorDescFromOutput(kernelCreationContext, 0/*Y OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
|
||||
m_outputTensorDescs[OUT_Y_SCALE] = CreateTensorDescFromOutput(kernelCreationContext, 1/*Y Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
|
||||
m_outputTensorDescs[OUT_Y_ZERO_POINT] = CreateTensorDescFromOutput(kernelCreationContext, 2/*Y Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_PLACEHOLDER_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.IndicesTensor = &inputDescs[0];
|
||||
operatorDesc.ValuesTensor = &inputDescs[1];
|
||||
operatorDesc.OutputTensor = outputDescs.data();
|
||||
operatorDesc.Axis = dmlAxis;
|
||||
DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[IN_A];
|
||||
operatorDesc.OutputTensor = &outputDescs[OUT_Y];
|
||||
operatorDesc.OutputScaleTensor = &outputDescs[OUT_Y_SCALE];
|
||||
operatorDesc.OutputZeroPointTensor = &outputDescs[OUT_Y_ZERO_POINT];
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PLACEHOLDER, &operatorDesc };
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &operatorDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelCreationContext);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -442,28 +442,34 @@ public:
|
|||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
float minValue = -FLT_MAX;
|
||||
float maxValue = FLT_MAX;
|
||||
if (kernelInfo.IsInputValid(1))
|
||||
{
|
||||
minValue = static_cast<float>(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(1)));
|
||||
}
|
||||
if (kernelInfo.IsInputValid(2)) {
|
||||
maxValue = static_cast<float>(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(2)));
|
||||
}
|
||||
|
||||
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC opDesc = {};
|
||||
DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = inputDescs.data();
|
||||
opDesc.OutputTensor = outputDescs.data();
|
||||
opDesc.Min = minValue;
|
||||
opDesc.Max = maxValue;
|
||||
// MinMaxDataType will always be equal to inputDataTensorDataType
|
||||
// Assigning minMaxDataType to inputDataTensorDataType because this field
|
||||
// has to be assigned even if program does not go through below conditional
|
||||
// logic for some corner test case
|
||||
// Same applies to min and max value.
|
||||
opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType();
|
||||
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min);
|
||||
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max);
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP, &opDesc}, kernelInfo);
|
||||
if (kernelInfo.IsInputValid(1))
|
||||
{
|
||||
ReadScalarTensorData(kernelInfo.GetConstantInputTensor(1), /*out*/ &opDesc.Min.Bytes, sizeof(opDesc.Min.Bytes));
|
||||
}
|
||||
if (kernelInfo.IsInputValid(2))
|
||||
{
|
||||
ReadScalarTensorData(kernelInfo.GetConstantInputTensor(2), /*out*/ &opDesc.Max.Bytes, sizeof(opDesc.Max.Bytes));
|
||||
}
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP1, &opDesc}, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
// Same operator signature as 11. Only difference is new type support
|
||||
using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11;
|
||||
using DmlOperatorElementwiseClip13 = DmlOperatorElementwiseClip11;
|
||||
|
||||
class DmlOperatorElementwisePow : public DmlOperator
|
||||
{
|
||||
|
|
@ -692,7 +698,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Ceil, DmlOperatorElementwiseUnary<DM
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Floor, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Not, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Sign, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SIGN_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(IsNan, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(IsNaN, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Sinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SINH_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Cosh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_COSH_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Asinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ASINH_OPERATOR_DESC>);
|
||||
|
|
@ -724,6 +730,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
|
|
|
|||
|
|
@ -22,16 +22,11 @@ public:
|
|||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_SCALE_BIAS scaleBias = {};
|
||||
scaleBias.Scale = -1.0f;
|
||||
scaleBias.Bias = 0.0f;
|
||||
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {};
|
||||
DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = inputDescs.data();
|
||||
opDesc.OutputTensor = outputDescs.data();
|
||||
opDesc.ScaleBias = &scaleBias;
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc}, kernelInfo);
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_NEGATE, &opDesc}, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -56,34 +56,40 @@ public:
|
|||
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
|
||||
}
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_PADDING1_OPERATOR_DESC paddingDesc = {};
|
||||
paddingDesc.InputTensor = inputDescs.data();
|
||||
paddingDesc.OutputTensor = outputDescs.data();
|
||||
paddingDesc.PaddingMode = mode;
|
||||
paddingDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_startPadding.size());
|
||||
paddingDesc.StartPadding = m_startPadding.data();
|
||||
paddingDesc.EndPadding = m_endPadding.data();
|
||||
// PaddingValueDataType will always be equal to inputDataTensorDataType
|
||||
// Assigning paddingValueDataType to inputDataTensorDataType because this field
|
||||
// has to be assigned even if program does not go through below conditional
|
||||
// logic for some corner test case (like opsetVersion >= 11, but no validInput at index 2)
|
||||
// Same applies to paddingValue.
|
||||
paddingDesc.PaddingValueDataType = this->m_inputTensorDescs[0].GetDmlDataType();
|
||||
CastToClampedScalarUnion<float>(paddingDesc.PaddingValueDataType, 0.0f, /*out*/&paddingDesc.PaddingValue);
|
||||
|
||||
// Read the constant value which can come from an attribute or tensor.
|
||||
float value = 0.0f;
|
||||
if (opsetVersion >= 11)
|
||||
{
|
||||
if (kernelInfo.IsInputValid(2))
|
||||
{
|
||||
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
|
||||
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
|
||||
MLOperatorTensor constantPaddingValueTensor = kernelInfo.GetConstantInputTensor(2);
|
||||
ReadScalarTensorData(constantPaddingValueTensor, /*out*/ &paddingDesc.PaddingValue.Bytes, sizeof(paddingDesc.PaddingValue.Bytes));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
value = kernelInfo.GetOptionalAttribute<float>(AttrName::Value, 0.0f);
|
||||
auto value = kernelInfo.GetOptionalAttribute<float>(AttrName::Value, 0.0f);
|
||||
CastToClampedScalarUnion<float>(paddingDesc.PaddingValueDataType, value, /*out*/&paddingDesc.PaddingValue);
|
||||
}
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_PADDING_OPERATOR_DESC paddingDesc = {};
|
||||
paddingDesc.InputTensor = inputDescs.data();
|
||||
paddingDesc.OutputTensor = outputDescs.data();
|
||||
paddingDesc.PaddingMode = mode;
|
||||
paddingDesc.PaddingValue = value;
|
||||
paddingDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_startPadding.size());
|
||||
paddingDesc.StartPadding = m_startPadding.data();
|
||||
paddingDesc.EndPadding = m_endPadding.data();
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING, &paddingDesc };
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING1, &paddingDesc };
|
||||
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
|
|
@ -91,5 +97,6 @@ public:
|
|||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel<DmlOperatorPadding, 13>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorQLinearAdd : public DmlOperator
|
||||
{
|
||||
enum InputTensors {
|
||||
IN_A,
|
||||
IN_A_SCALE,
|
||||
IN_A_ZERO_POINT,
|
||||
IN_B,
|
||||
IN_B_SCALE,
|
||||
IN_B_ZERO_POINT,
|
||||
IN_C_SCALE,
|
||||
IN_C_ZERO_POINT
|
||||
};
|
||||
|
||||
public:
|
||||
DmlOperatorQLinearAdd(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperator(kernelInfo)
|
||||
{
|
||||
DmlOperator::Initialize(kernelInfo);
|
||||
|
||||
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
|
||||
uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount();
|
||||
|
||||
// Initialize the input descriptions with broadcasting
|
||||
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
|
||||
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
|
||||
|
||||
m_inputTensorDescs[IN_A_SCALE] = CreateTensorDescFromInput(kernelInfo, 1/*A Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
m_inputTensorDescs[IN_A_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 2/*A Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
|
||||
m_inputTensorDescs[IN_B_SCALE] = CreateTensorDescFromInput(kernelInfo, 4/*B Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
m_inputTensorDescs[IN_B_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 5/*B Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
|
||||
m_inputTensorDescs[IN_C_SCALE] = CreateTensorDescFromInput(kernelInfo, 6/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
m_inputTensorDescs[IN_C_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 7/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
|
||||
|
||||
// Initialize the output description while overriding the shape
|
||||
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC AddDesc = {};
|
||||
AddDesc.ATensor = &inputDescs[IN_A];
|
||||
AddDesc.AScaleTensor = &inputDescs[IN_A_SCALE];
|
||||
AddDesc.AZeroPointTensor = inputDescs[IN_A_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_A_ZERO_POINT] : nullptr;
|
||||
AddDesc.BTensor = &inputDescs[IN_B];
|
||||
AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE];
|
||||
AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr;
|
||||
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
|
||||
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
|
||||
AddDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QLinearAdd, DmlOperatorQLinearAdd);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -341,5 +341,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel<DmlOperatorResize, 11>
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel<DmlOperatorResize, 7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel<DmlOperatorResize, 9>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel<DmlOperatorResize, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Upsample13, VersionedKernel<DmlOperatorResize, 13>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ public:
|
|||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd);
|
||||
|
||||
|
|
|
|||
|
|
@ -55,4 +55,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* i
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel<DmlOperatorSlice, 7> );
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel<DmlOperatorSlice, 10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel<DmlOperatorSlice, 11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice13, VersionedKernel<DmlOperatorSlice, 13>);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -35,14 +35,18 @@ enum class SupportedTensorDataTypes : uint32_t
|
|||
Complex64 = 1<<14,
|
||||
Complex128 = 1<<15,
|
||||
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Int32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
|
||||
NumericDefault = Ints8to32|Float16to32,
|
||||
Ints32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64,
|
||||
UInt8to64 = UInt8|UInt16|UInt32|UInt64,
|
||||
Float16to32 = Float16|Float32,
|
||||
Float16to64 = Float16|Float32|Float64,
|
||||
NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string.
|
||||
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
||||
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
|
||||
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool,
|
||||
Ints8Bit = UInt8|Int8,
|
||||
Ints16Bit = UInt16|Int16,
|
||||
Ints32Bit = UInt32|Int32,
|
||||
Ints64Bit = UInt64|Int64,
|
||||
All = static_cast<uint32_t>(-1),
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
|
||||
|
|
@ -54,6 +58,7 @@ enum class DmlGraphSupport : uint32_t
|
|||
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
|
||||
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
|
||||
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
|
||||
Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types.
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
|
||||
|
||||
|
|
@ -109,8 +114,10 @@ DML_OP_EXTERN_CREATION_FUNCTION(Concat);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Slice7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Slice10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Slice11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Slice13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Pad7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Pad11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Pad13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Sqrt);
|
||||
|
|
@ -124,6 +131,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Floor);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Clip7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip12);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Greater);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Less);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual);
|
||||
|
|
@ -161,6 +169,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ImageScaler);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Upsample7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Upsample9);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Upsample10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Upsample13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Sigmoid);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(HardSigmoid);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Tanh);
|
||||
|
|
@ -206,7 +215,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedSum);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Sign);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(IsNan);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(IsNaN);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Sinh);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Cosh);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Tanh);
|
||||
|
|
@ -221,6 +230,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(EyeLike);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter9);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Scatter13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Resize10);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Resize11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape);
|
||||
|
|
@ -235,6 +245,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ReverseSequence);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Round);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ScatterElements);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
|
||||
|
|
@ -251,6 +262,7 @@ constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
|
|||
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
|
||||
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
|
||||
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
|
||||
constexpr static std::array<const char*, 2> typeNameListMaxPool = { "T", "I" };
|
||||
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
|
||||
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
|
||||
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
|
||||
|
|
@ -259,51 +271,65 @@ constexpr static std::array<const char*, 1> typeNameListScatterGatherND = { "T"
|
|||
constexpr static std::array<const char*, 2> typeNameListSlice10 = { "T", "Tind" };
|
||||
constexpr static std::array<const char*, 2> typeNameListWhere = { "B", "T" };
|
||||
constexpr static std::array<const char*, 1> typeNameListEyeLike = { "T2" };
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints32to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault | SupportedTensorDataTypes::Ints64Bit, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxUnpool = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
|
||||
};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int32
|
||||
};
|
||||
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
|
||||
SupportedTensorDataTypes::Float32,
|
||||
SupportedTensorDataTypes::UInt8,
|
||||
};
|
||||
|
||||
template<typename... Args>
|
||||
constexpr auto requiredConstantCpuInputs(Args... args)
|
||||
{
|
||||
|
|
@ -336,249 +362,313 @@ constexpr auto requiredConstantCpuInputs(Args... args)
|
|||
|
||||
constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] =
|
||||
{
|
||||
/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs,
|
||||
/// Input count required for graph support,
|
||||
/// Support query function
|
||||
/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs,
|
||||
/// Input count required for graph support,
|
||||
/// Support query function
|
||||
|
||||
// Deep Learning Standard Layers
|
||||
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
|
||||
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
|
||||
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
|
||||
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
|
||||
|
||||
// Data Reorganization Layers
|
||||
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
||||
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
|
||||
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
|
||||
{REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 13, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 13, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(0))},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 13, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
|
||||
|
||||
// Data reorganization that merely changes the dimensions while keeping the data identical.
|
||||
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_ID( 13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Elementwise
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
|
||||
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO_VER( 13, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 13, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 13, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 13, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 13, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
|
||||
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
|
||||
// Imaging Operators
|
||||
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
||||
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 13, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
|
||||
{REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
|
||||
|
||||
// Activation Functions
|
||||
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
|
||||
// Uncategorized
|
||||
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Fused operators
|
||||
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
|
||||
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
|
||||
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
|
||||
|
||||
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))},
|
||||
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
|
||||
|
||||
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -607,6 +697,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
|
||||
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
|
||||
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
|
||||
bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation);
|
||||
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
|
||||
|
||||
|
|
@ -692,6 +783,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
support64BitTensorsViaEmulation,
|
||||
information.requiredConstantCpuInputs.first.data(),
|
||||
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
|
||||
));
|
||||
|
|
|
|||
|
|
@ -151,15 +151,20 @@ namespace Dml
|
|||
OperatorInfo{ "InstanceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_InstanceNormalization },
|
||||
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MeanVarianceNormalization },
|
||||
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MeanVarianceNormalization },
|
||||
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MeanVarianceNormalization },
|
||||
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm },
|
||||
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm },
|
||||
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm },
|
||||
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm },
|
||||
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul },
|
||||
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul },
|
||||
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul },
|
||||
|
||||
// The filter for activation functions maps to what DML's fused op internally fuses at the shader level.
|
||||
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
|
||||
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
|
||||
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
|
||||
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
|
||||
};
|
||||
|
||||
// Not all activations can be fused - only simple elementwise activations (i.e. activation functions which
|
||||
|
|
@ -167,10 +172,13 @@ namespace Dml
|
|||
static const OperatorInfo c_activationOps[] =
|
||||
{
|
||||
OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Sigmoid },
|
||||
OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sigmoid },
|
||||
OperatorInfo{ "HardSigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_HardSigmoid },
|
||||
OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Tanh },
|
||||
OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Tanh },
|
||||
OperatorInfo{ "ScaledTanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ScaledTanh },
|
||||
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Relu },
|
||||
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu },
|
||||
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu },
|
||||
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu },
|
||||
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu },
|
||||
|
|
|
|||
|
|
@ -4,27 +4,47 @@
|
|||
#pragma once
|
||||
|
||||
#define ML_CHECK_VALID_ARGUMENT(x, ...)\
|
||||
{\
|
||||
if ((x) == false) {\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
}\
|
||||
}
|
||||
{\
|
||||
if ((x) == false)\
|
||||
{\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
}\
|
||||
}
|
||||
|
||||
#define ML_INVALID_ARGUMENT(msg)\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
|
||||
#define ML_CHECK_HRESULT(hr, ...)\
|
||||
{\
|
||||
if (FAILED(hr)) {\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
}\
|
||||
}
|
||||
{\
|
||||
if (FAILED(hr))\
|
||||
{\
|
||||
ORT_THROW_HR(E_INVALIDARG);\
|
||||
}\
|
||||
}
|
||||
|
||||
namespace OperatorHelper
|
||||
{
|
||||
template<typename T, typename I> T clamp_cast(I input)
|
||||
// Clamp the input value to the maximum range of the output, before casting to the output type.
|
||||
//
|
||||
// e.g. int32(300) would yield int8(255) rather than int8(44).
|
||||
// float32(-42) would yield uint32(0) rather than a huge positive number.
|
||||
template<typename OutputType, typename InputType> OutputType clamp_cast(InputType input)
|
||||
{
|
||||
return static_cast<T>(std::clamp<I>(input, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max()));
|
||||
// Determine the larger type to decide which numeric limits to clamp to.
|
||||
using InputLimits = std::numeric_limits<InputType>;
|
||||
using OutputLimits = std::numeric_limits<OutputType>;
|
||||
constexpr int inputMaxDigits = std::max(InputLimits::max_exponent, InputLimits::digits);
|
||||
constexpr int outputMaxDigits = std::max(OutputLimits::max_exponent, OutputLimits::digits);
|
||||
constexpr bool isEitherTypeUnsigned = std::is_unsigned_v<InputType> || std::is_unsigned_v<OutputType>;
|
||||
constexpr bool isOutputTypeLarger = outputMaxDigits > inputMaxDigits;
|
||||
|
||||
InputType lowestValue = isEitherTypeUnsigned ? static_cast<InputType>(0) :
|
||||
isOutputTypeLarger ? InputLimits::lowest() :
|
||||
static_cast<InputType>(OutputLimits::lowest());
|
||||
InputType highestValue = isOutputTypeLarger ? InputLimits::max() :
|
||||
static_cast<InputType>(OutputLimits::max());
|
||||
|
||||
return static_cast<OutputType>(std::clamp<InputType>(input, lowestValue, highestValue));
|
||||
}
|
||||
enum TensorAxis { N, C, H, W, DoNotCoerce = INT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 };
|
||||
enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast };
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ IMLOperatorRegistryPrivate : public IUnknown
|
|||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
bool support64BitTensorsViaEmulation = false,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
|
||||
uint32_t constantCpuInputCount = 0
|
||||
) const noexcept PURE;
|
||||
|
|
|
|||
|
|
@ -32,87 +32,6 @@ namespace OperatorHelper
|
|||
}
|
||||
}
|
||||
|
||||
void ReadCpuLocalTensorIntoInt32(
|
||||
const MLOperatorTensor& tensor,
|
||||
std::vector<int32_t>& result
|
||||
)
|
||||
{
|
||||
result.clear();
|
||||
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
|
||||
|
||||
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
|
||||
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
|
||||
|
||||
switch (tensor.GetTensorDataType())
|
||||
{
|
||||
case MLOperatorTensorDataType::Int32:
|
||||
{
|
||||
const int32_t* data = tensor.GetData<int32_t>();
|
||||
result.assign(data, data + elementCount);
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::Int64:
|
||||
{
|
||||
const int64_t* data = tensor.GetData<int64_t>();
|
||||
result.reserve(elementCount);
|
||||
|
||||
// Use clamped cast rather than static_cast/narrow_cast,
|
||||
// because it's not uncommon for a model to specify a
|
||||
// 64-bit INTMAX constant as a sentinel value to mean
|
||||
// the largest possible value (even though the actual
|
||||
// dimension values come nowhere close to that, far
|
||||
// less than 32-bit INTMAX).
|
||||
for (auto d : gsl::make_span(data, data + elementCount))
|
||||
{
|
||||
result.push_back(clamp_cast<int32_t>(d));
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ReadCpuLocalTensorIntoFloat32(
|
||||
const MLOperatorTensor& tensor,
|
||||
std::vector<float>& result
|
||||
)
|
||||
{
|
||||
result.clear();
|
||||
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
|
||||
|
||||
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
|
||||
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
|
||||
|
||||
switch (tensor.GetTensorDataType())
|
||||
{
|
||||
case MLOperatorTensorDataType::Float:
|
||||
{
|
||||
const float* data = tensor.GetData<float>();
|
||||
result.assign(data, data + elementCount);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float32.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
|
||||
{
|
||||
outputDimensions.reserve(inputDimensions.size());
|
||||
outputDimensions.clear();
|
||||
|
||||
for (int64_t dim : inputDimensions)
|
||||
{
|
||||
outputDimensions.push_back(gsl::narrow_cast<uint32_t>(std::clamp<int64_t>(dim, INT32_MIN, INT32_MAX)));
|
||||
}
|
||||
}
|
||||
|
||||
float CastFloat16ToFloat32(uint16_t input)
|
||||
{
|
||||
// Promote float16m10e5s1 to float32m23e8s1.
|
||||
|
|
@ -201,6 +120,130 @@ namespace OperatorHelper
|
|||
}
|
||||
#pragma warning(pop)
|
||||
|
||||
void ReadCpuLocalTensorIntoInt32(
|
||||
const MLOperatorTensor& tensor,
|
||||
std::vector<int32_t>& result
|
||||
)
|
||||
{
|
||||
result.clear();
|
||||
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
|
||||
|
||||
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
|
||||
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
|
||||
|
||||
switch (tensor.GetTensorDataType())
|
||||
{
|
||||
case MLOperatorTensorDataType::Int32:
|
||||
{
|
||||
const int32_t* data = tensor.GetData<int32_t>();
|
||||
result.assign(data, data + elementCount);
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::Int64:
|
||||
{
|
||||
const int64_t* data = tensor.GetData<int64_t>();
|
||||
result.reserve(elementCount);
|
||||
|
||||
// Use clamped cast rather than static_cast/narrow_cast,
|
||||
// because it's not uncommon for a model to specify a
|
||||
// 64-bit INTMAX constant as a sentinel value to mean
|
||||
// the largest possible value (even though the actual
|
||||
// dimension values come nowhere close to that, far
|
||||
// less than 32-bit INTMAX).
|
||||
for (auto d : gsl::make_span(data, data + elementCount))
|
||||
{
|
||||
result.push_back(clamp_cast<int32_t>(d));
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ReadCpuLocalTensorIntoFloat32(
|
||||
const MLOperatorTensor& tensor,
|
||||
std::vector<float>& result
|
||||
)
|
||||
{
|
||||
result.clear();
|
||||
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
|
||||
|
||||
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
|
||||
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
|
||||
result.resize(elementCount);
|
||||
|
||||
switch (tensor.GetTensorDataType())
|
||||
{
|
||||
case MLOperatorTensorDataType::Float16:
|
||||
{
|
||||
const uint16_t* data = reinterpret_cast<const uint16_t*>(tensor.GetByteData());
|
||||
std::transform(data, data + elementCount, result.begin(), CastFloat16ToFloat32);
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::/*Float32*/Float:
|
||||
{
|
||||
const float* data = tensor.GetData<float>();
|
||||
result.assign(data, data + elementCount);
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::/*Float64*/Double:
|
||||
{
|
||||
const double* data = tensor.GetData<double>();
|
||||
std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast<float>(v); });
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::Int32:
|
||||
{
|
||||
const int32_t* data = tensor.GetData<int32_t>();
|
||||
std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast<float>(v); });
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::UInt32:
|
||||
{
|
||||
const uint32_t* data = tensor.GetData<uint32_t>();
|
||||
std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast<float>(v); });
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::Int64:
|
||||
{
|
||||
const int64_t* data = tensor.GetData<int64_t>();
|
||||
std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast<float>(v); });
|
||||
}
|
||||
break;
|
||||
|
||||
case MLOperatorTensorDataType::UInt64:
|
||||
{
|
||||
const uint64_t* data = tensor.GetData<uint64_t>();
|
||||
std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast<float>(v); });
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float32.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
|
||||
{
|
||||
outputDimensions.reserve(inputDimensions.size());
|
||||
outputDimensions.clear();
|
||||
|
||||
for (int64_t dim : inputDimensions)
|
||||
{
|
||||
outputDimensions.push_back(gsl::narrow_cast<uint32_t>(std::clamp<int64_t>(dim, INT32_MIN, INT32_MAX)));
|
||||
}
|
||||
}
|
||||
|
||||
int64_t IsFloatDataType(MLOperatorTensorDataType tensorDataType)
|
||||
{
|
||||
switch (tensorDataType)
|
||||
|
|
@ -1089,7 +1132,7 @@ namespace OperatorHelper
|
|||
ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count.");
|
||||
ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor.");
|
||||
|
||||
std::vector<uint32_t> labelSizes(m_labelIndices.size(), static_cast<uint32_t>(INT_MIN));
|
||||
std::vector<uint32_t> labelSizes(m_labelIndices.size(), UINT_MAX);
|
||||
|
||||
// Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier.
|
||||
for (uint32_t i = 0; i < inputCount; ++i)
|
||||
|
|
@ -1110,7 +1153,7 @@ namespace OperatorHelper
|
|||
uint32_t labelIndex = labelIndices[j];
|
||||
assert(labelIndex < labelSizes.size());
|
||||
|
||||
if (labelSizes[labelIndex] == INT_MIN)
|
||||
if (labelSizes[labelIndex] == UINT_MAX)
|
||||
{
|
||||
labelSizes[labelIndex] = dimensionSize;
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -142,7 +142,7 @@ namespace OperatorHelper
|
|||
namespace OnnxOperatorSet9
|
||||
{
|
||||
static const int sc_sinceVer_Sign = 9;
|
||||
static const int sc_sinceVer_IsNan = 9;
|
||||
static const int sc_sinceVer_IsNaN = 9;
|
||||
static const int sc_sinceVer_Sinh = 9;
|
||||
static const int sc_sinceVer_Cosh = 9;
|
||||
static const int sc_sinceVer_Asinh = 9;
|
||||
|
|
@ -233,6 +233,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_ReduceSumSquare = 11;
|
||||
static const int sc_sinceVer_Resize = 11;
|
||||
static const int sc_sinceVer_Round = 11;
|
||||
static const int sc_sinceVer_DynamicQuantizeLinear = 11;
|
||||
static const int sc_sinceVer_Scan = 11;
|
||||
static const int sc_sinceVer_Scatter = 11; // Deprecated alias
|
||||
static const int sc_sinceVer_ScatterElements = 11;
|
||||
|
|
@ -263,6 +264,73 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_ReduceMin = 12;
|
||||
} // namespace OnnxOperatorSet12
|
||||
|
||||
namespace OnnxOperatorSet13
|
||||
{
|
||||
static const int sc_sinceVer_Abs = 13;
|
||||
static const int sc_sinceVer_Add = 13;
|
||||
static const int sc_sinceVer_ArgMax = 13;
|
||||
static const int sc_sinceVer_ArgMin = 13;
|
||||
static const int sc_sinceVer_Cast = 13;
|
||||
static const int sc_sinceVer_Ceil = 13;
|
||||
static const int sc_sinceVer_Clip = 13;
|
||||
static const int sc_sinceVer_Concat = 13;
|
||||
static const int sc_sinceVer_Constant = 13;
|
||||
static const int sc_sinceVer_DepthToSpace = 13;
|
||||
static const int sc_sinceVer_Div = 13;
|
||||
static const int sc_sinceVer_Equal = 13;
|
||||
static const int sc_sinceVer_Erf = 13;
|
||||
static const int sc_sinceVer_Exp = 13;
|
||||
static const int sc_sinceVer_Expand = 13;
|
||||
static const int sc_sinceVer_Flatten = 13;
|
||||
static const int sc_sinceVer_Floor = 13;
|
||||
static const int sc_sinceVer_Gather = 13;
|
||||
static const int sc_sinceVer_GatherElements = 13;
|
||||
static const int sc_sinceVer_GatherND = 13;
|
||||
static const int sc_sinceVer_Gemm = 13;
|
||||
static const int sc_sinceVer_Greater = 13;
|
||||
static const int sc_sinceVer_Identity = 13;
|
||||
static const int sc_sinceVer_IsNaN = 13;
|
||||
static const int sc_sinceVer_LRN = 13;
|
||||
static const int sc_sinceVer_Less = 13;
|
||||
static const int sc_sinceVer_Log = 13;
|
||||
static const int sc_sinceVer_MatMul = 13;
|
||||
static const int sc_sinceVer_Max = 13;
|
||||
static const int sc_sinceVer_Mean = 13;
|
||||
static const int sc_sinceVer_MeanVarianceNormalization = 13;
|
||||
static const int sc_sinceVer_Min = 13;
|
||||
static const int sc_sinceVer_Mod = 13;
|
||||
static const int sc_sinceVer_Mul = 13;
|
||||
static const int sc_sinceVer_Neg = 13;
|
||||
static const int sc_sinceVer_Pad = 13;
|
||||
static const int sc_sinceVer_Pow = 13;
|
||||
static const int sc_sinceVer_Reciprocal = 13;
|
||||
static const int sc_sinceVer_ReduceL1 = 13;
|
||||
static const int sc_sinceVer_ReduceL2 = 13;
|
||||
static const int sc_sinceVer_ReduceLogSum = 13;
|
||||
static const int sc_sinceVer_ReduceLogSumExp = 13;
|
||||
static const int sc_sinceVer_ReduceMax = 13;
|
||||
static const int sc_sinceVer_ReduceMean = 13;
|
||||
static const int sc_sinceVer_ReduceMin = 13;
|
||||
static const int sc_sinceVer_ReduceProd = 13;
|
||||
static const int sc_sinceVer_ReduceSumSquare = 13;
|
||||
static const int sc_sinceVer_Relu = 13;
|
||||
static const int sc_sinceVer_Reshape = 13;
|
||||
static const int sc_sinceVer_Scatter = 13;
|
||||
static const int sc_sinceVer_ScatterElements = 13;
|
||||
static const int sc_sinceVer_ScatterND = 13;
|
||||
static const int sc_sinceVer_Sigmoid = 13;
|
||||
static const int sc_sinceVer_Sign = 13;
|
||||
static const int sc_sinceVer_Slice = 13;
|
||||
static const int sc_sinceVer_SpaceToDepth = 13;
|
||||
static const int sc_sinceVer_Sqrt = 13;
|
||||
static const int sc_sinceVer_Sub = 13;
|
||||
static const int sc_sinceVer_Sum = 13;
|
||||
static const int sc_sinceVer_Tanh = 13;
|
||||
static const int sc_sinceVer_Tile = 13;
|
||||
static const int sc_sinceVer_Transpose = 13;
|
||||
static const int sc_sinceVer_Upsample = 13;
|
||||
} // namespace OnnxOperatorSet13
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
{
|
||||
static const int sc_sinceVer_FusedConv = 1;
|
||||
|
|
@ -277,6 +345,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_QuantizeLinear = 1;
|
||||
static const int sc_sinceVer_DequantizeLinear = 1;
|
||||
static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1;
|
||||
static const int sc_sinceVer_QLinearAdd = 1;
|
||||
} // namespace MsftOperatorSet1
|
||||
|
||||
} // namespace OperatorHelper
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
bool support64BitTensorsViaEmulation,
|
||||
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
|
||||
uint32_t constantCpuInputCount) const noexcept try {
|
||||
#ifdef LAYERING_DONE
|
||||
|
|
@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
support64BitTensorsViaEmulation,
|
||||
requiredConstantCpuInputs,
|
||||
constantCpuInputCount);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
|
|||
bool supports_64bit_directly = false,
|
||||
bool allows_64bit_via_strides = false,
|
||||
bool allows_64bit_via_strides_from_any_ep = false,
|
||||
bool supports_64bit_tensors_via_emulation = false,
|
||||
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
|
||||
uint32_t constant_cpu_input_count = 0) const noexcept override;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue