Merge branch 'DmlDev' into user/dwayner/DmlEpGatherScatterReverseRangeInfModRoundBitshiftCumSumClip

This commit is contained in:
Dwayne Robinson 2020-03-27 18:54:30 -07:00
commit 351c3c30fb
11 changed files with 623 additions and 117 deletions

View file

@ -25,7 +25,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
// EnumTraits specialization for DML_OPERATOR_TYPE.
// ValueCount is presumably the number of DML_OPERATOR_TYPE enumerators and has to be
// bumped whenever operators are added — TODO confirm against the enum definition.
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 110;
static constexpr auto ValueCount = 119;
// Presumably the number of DML_OPERATOR_ACTIVATION_* entries — verify against the enum.
static constexpr size_t ActivationFunctionCount = 19;
};
@ -96,6 +96,12 @@ struct EnumTraits<DML_IS_INFINITY_MODE>
static constexpr auto ValueCount = 3;
};
// EnumTraits specialization for DML_DEPTH_SPACE_ORDER; exposes ValueCount == 2
// (presumably the number of enumerators — confirm against the enum definition).
template <>
struct EnumTraits<DML_DEPTH_SPACE_ORDER>
{
    static constexpr int ValueCount{2};
};
template <>
struct EnumTraits<DML_AXIS_DIRECTION>
{
@ -706,6 +712,60 @@ struct OperatorDescTraits<DML_SCATTER_ND_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND;
};
// Associates the DML_MAX_POOLING2_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_MAX_POOLING2_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_MAX_POOLING2};
};
// Associates the DML_SLICE1_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_SLICE1_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_SLICE1};
};
// Associates the DML_TOP_K1_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_TOP_K1_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_TOP_K1};
};
// Associates the DML_DEPTH_TO_SPACE1_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_DEPTH_TO_SPACE1_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_DEPTH_TO_SPACE1};
};
// Associates the DML_SPACE_TO_DEPTH1_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_SPACE_TO_DEPTH1_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_SPACE_TO_DEPTH1};
};
// Associates the DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1};
};
// Associates the DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY};
};
// Associates the DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION};
};
// Associates the DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC struct with its operator type value.
template <>
struct OperatorDescTraits<DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type{DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR};
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
@ -1366,6 +1426,60 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND>
using DescType = DML_SCATTER_ND_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING2>
{
using DescType = DML_MAX_POOLING2_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE1>
{
using DescType = DML_SLICE1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K1>
{
using DescType = DML_TOP_K1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE1>
{
using DescType = DML_DEPTH_TO_SPACE1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH1>
{
using DescType = DML_SPACE_TO_DEPTH1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1>
{
using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY>
{
using DescType = DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION>
{
using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR>
{
using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
{
@ -1673,6 +1787,24 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_SCATTER_ND:
return std::invoke(std::forward<Visitor>(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MAX_POOLING2:
return std::invoke(std::forward<Visitor>(visitor), DML_MAX_POOLING2_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_SLICE1:
return std::invoke(std::forward<Visitor>(visitor), DML_SLICE1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_TOP_K1:
return std::invoke(std::forward<Visitor>(visitor), DML_TOP_K1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_DEPTH_TO_SPACE1:
return std::invoke(std::forward<Visitor>(visitor), DML_DEPTH_TO_SPACE1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_SPACE_TO_DEPTH1:
return std::invoke(std::forward<Visitor>(visitor), DML_SPACE_TO_DEPTH1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1:
return std::invoke(std::forward<Visitor>(visitor), DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
return std::invoke(std::forward<Visitor>(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_HARDMAX:
@ -1813,6 +1945,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND";
case DML_OPERATOR_MAX_POOLING2: return "DML_OPERATOR_MAX_POOLING2";
case DML_OPERATOR_SLICE1: return "DML_OPERATOR_SLICE1";
case DML_OPERATOR_TOP_K1: return "DML_OPERATOR_TOP_K1";
case DML_OPERATOR_DEPTH_TO_SPACE1: return "DML_OPERATOR_DEPTH_TO_SPACE1";
case DML_OPERATOR_SPACE_TO_DEPTH1: return "DML_OPERATOR_SPACE_TO_DEPTH1";
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1";
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY";
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION";
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR";
default:
assert(false);
return "<unknown>";

View file

@ -1441,6 +1441,172 @@ constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA {
DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_MAX_POOLING2 (9 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..8] indices used by
// GetFields(const DML_MAX_POOLING2_OPERATOR_DESC&).
// The trailing bool is presumably the "optional" flag (set only for
// OutputIndicesTensor) — confirm against the DML_SCHEMA_FIELD definition.
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndicesTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
};
// Schema record for DML_OPERATOR_MAX_POOLING2. Positional initializers, in order:
// operator name string, operator type, support flags, field count (9, the size of
// DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING2_OPERATOR_SCHEMA {
"DML_OPERATOR_MAX_POOLING2",
DML_OPERATOR_MAX_POOLING2,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
9,
DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_SLICE1 (6 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..5] indices used by
// GetFields(const DML_SLICE1_OPERATOR_DESC&).
// InputWindowStrides is the only INT_ARRAY field; offsets and sizes are UINT_ARRAY.
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_SLICE1_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false },
};
// Schema record for DML_OPERATOR_SLICE1. Positional initializers, in order:
// operator name string, operator type, support flags, field count (6, the size of
// DML_SLICE1_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_SLICE1_OPERATOR_SCHEMA {
"DML_OPERATOR_SLICE1",
DML_OPERATOR_SLICE1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
6,
DML_SLICE1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_TOP_K1 (6 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..5] indices used by
// GetFields(const DML_TOP_K1_OPERATOR_DESC&).
// AxisDirection is recorded as a plain UINT attribute.
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_TOP_K1_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputValueTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndexTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "K", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
};
// Schema record for DML_OPERATOR_TOP_K1. Positional initializers, in order:
// operator name string, operator type, support flags, field count (6, the size of
// DML_TOP_K1_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_TOP_K1_OPERATOR_SCHEMA {
"DML_OPERATOR_TOP_K1",
DML_OPERATOR_TOP_K1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
6,
DML_TOP_K1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_DEPTH_TO_SPACE1 (4 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..3] indices used by
// GetFields(const DML_DEPTH_TO_SPACE1_OPERATOR_DESC&).
// Order is recorded as a plain UINT attribute.
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false },
};
// Schema record for DML_OPERATOR_DEPTH_TO_SPACE1. Positional initializers, in order:
// operator name string, operator type, support flags, field count (4, the size of
// DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA {
"DML_OPERATOR_DEPTH_TO_SPACE1",
DML_OPERATOR_DEPTH_TO_SPACE1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_SPACE_TO_DEPTH1 (4 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..3] indices used by
// GetFields(const DML_SPACE_TO_DEPTH1_OPERATOR_DESC&).
// Order is recorded as a plain UINT attribute.
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false },
};
// Schema record for DML_OPERATOR_SPACE_TO_DEPTH1. Positional initializers, in order:
// operator name string, operator type, support flags, field count (4, the size of
// DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA {
"DML_OPERATOR_SPACE_TO_DEPTH1",
DML_OPERATOR_SPACE_TO_DEPTH1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 (9 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..8] indices used by
// GetFields(const DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC&).
// The trailing bool is presumably the "optional" flag (set for ScaleTensor,
// BiasTensor and FusedActivation) — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "NormalizeVariance", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true },
};
// Schema record for DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1. Positional initializers,
// in order: operator name string, operator type, support flags, field count (9, the size of
// DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA {
"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1",
DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
9,
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY (9 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..8] indices used by
// GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC&).
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
// NOTE(review): every tensor is flagged false here; the public DirectML header appears to
// mark the A/B/Output zero-point tensors as nullable — confirm whether those flags should be true.
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
// Schema record for DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY. Positional initializers,
// in order: operator name string, operator type, support flags, field count (9, the size of
// DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA {
"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY",
DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
9,
DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION (16 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..15] indices used by
// GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC&).
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
// NOTE(review): every tensor is flagged false here; the public DirectML header appears to
// mark the zero-point tensors and BiasTensor as nullable — confirm whether those flags should be true.
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS[16] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "GroupCount", false },
};
// Schema record for DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION. Positional initializers,
// in order: operator name string, operator type, support flags, field count (16, the size of
// DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA {
"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION",
DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
16,
DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR (4 entries: kind, type, name, bool flag).
// Entry order must match the Fields[0..3] indices used by
// GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC&).
// One input tensor and three output tensors; the operator has no attribute fields.
// The trailing bool is presumably the "optional" flag — confirm against DML_SCHEMA_FIELD.
constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false },
};
// Schema record for DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR. Positional initializers,
// in order: operator name string, operator type, support flags, field count (4, the size of
// DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS) and the field table itself.
constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA {
"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR",
DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
4,
DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },

View file

@ -852,6 +852,118 @@ inline std::vector<OperatorField> GetFields(const DML_SCATTER_ND_OPERATOR_DESC&
OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.IndicesDimensionCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING2_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputIndicesTensor))),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_SLICE1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowOffsets), desc.DimensionCount)),
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.InputWindowSizes), desc.DimensionCount)),
OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const INT*>(desc.InputWindowStrides), desc.DimensionCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_TOP_K1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputValueTensor))),
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputIndexTensor))),
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.K))),
OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
};
}
inline std::vector<OperatorField> GetFields(const DML_DEPTH_TO_SPACE1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.BlockSize))),
OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Order))),
};
}
inline std::vector<OperatorField> GetFields(const DML_SPACE_TO_DEPTH1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.BlockSize))),
OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.Order))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisCount))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.Axes), desc.AxisCount)),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.NormalizeVariance))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_OPERATOR_DESC*>(desc.FusedActivation))),
};
}
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.FilterTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.FilterScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.FilterZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast<UINT>(desc.GroupCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@ -1097,6 +1209,15 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA;
case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING2: return DML_MAX_POOLING2_OPERATOR_SCHEMA;
case DML_OPERATOR_SLICE1: return DML_SLICE1_OPERATOR_SCHEMA;
case DML_OPERATOR_TOP_K1: return DML_TOP_K1_OPERATOR_SCHEMA;
case DML_OPERATOR_DEPTH_TO_SPACE1: return DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA;
case DML_OPERATOR_SPACE_TO_DEPTH1: return DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA;
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA;
@ -1485,6 +1606,42 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_SCATTER_ND_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_SCATTER_ND_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MAX_POOLING2:
return AbstractOperatorDesc(
&DML_MAX_POOLING2_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MAX_POOLING2_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_SLICE1:
return AbstractOperatorDesc(
&DML_SLICE1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_SLICE1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_TOP_K1:
return AbstractOperatorDesc(
&DML_TOP_K1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_TOP_K1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_DEPTH_TO_SPACE1:
return AbstractOperatorDesc(
&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_DEPTH_TO_SPACE1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_SPACE_TO_DEPTH1:
return AbstractOperatorDesc(
&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_SPACE_TO_DEPTH1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1:
return AbstractOperatorDesc(
&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY:
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
return AbstractOperatorDesc(
&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,

View file

@ -16,7 +16,7 @@ using ApiAttributeVariant = std::variant<
const DML_SCALE_BIAS*,
DML_SIZE_2D,
DML_SCALAR_UNION
>;
>;
namespace OperatorFieldTypes
{
@ -51,7 +51,7 @@ using OperatorFieldVariant = std::variant<
OperatorFieldTypes::ScaleBias,
OperatorFieldTypes::Size2D,
OperatorFieldTypes::ScalarUnion
>;
>;
class OperatorField
{

View file

@ -1732,7 +1732,8 @@ void InferAndVerifyOutputSizes(
for (uint32_t output_dim = 0; output_dim < outputShapes.GetShape(outputIndex).size(); ++output_dim) {
if (shape.dim(output_dim).has_dim_value()) {
int64_t expected_size = shape.dim(output_dim).dim_value();
ML_CHECK_BOOL(expected_size == outputShapes.GetShape(outputIndex)[output_dim]);
int64_t actual_size = outputShapes.GetShape(outputIndex)[output_dim];
ML_CHECK_BOOL(expected_size == actual_size);
}
}
}

View file

@ -13,49 +13,36 @@ public:
: DmlOperator(kernelInfo),
SliceHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion)
{
uint32_t minInputCount = (opsetVersion < 10) ? 1 : 3;
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= minInputCount);
const uint32_t inputCount = kernelInfo.GetInputCount();
ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1)
|| (opsetVersion >= 10 && opsetVersion <= 11 && inputCount >= 3 && inputCount <= 5));
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
// TODO (23108599): Slice V10 introduces an optional "Steps" input which the kernel does not yet support.
THROW_HR_IF(E_NOTIMPL, kernelInfo.GetInputCount() > 4);
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 };
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_offsets.size()));
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_sizes.size()));
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_strides.size()));
const uint32_t inputTensorRank = m_inputTensorDescs[0].GetDimensionCount();
assert(inputTensorRank >= gsl::narrow_cast<uint32_t>(m_offsets.size()));
assert(inputTensorRank >= gsl::narrow_cast<uint32_t>(m_sizes.size()));
assert(inputTensorRank >= gsl::narrow_cast<uint32_t>(m_strides.size()));
// Pad the parameters to respect DML's requirements
m_offsets.insert(
m_offsets.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_offsets.size()),
0);
m_sizes.insert(
m_sizes.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_sizes.size()),
1);
m_strides.insert(
m_strides.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_strides.size()),
1);
FillWithLeadingValues(/*inout*/ m_offsets, inputTensorRank, 0u);
FillWithLeadingValues(/*inout*/ m_sizes, inputTensorRank, 1u);
FillWithLeadingValues(/*inout*/ m_strides, inputTensorRank, 1);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SLICE_OPERATOR_DESC sliceDesc = {};
DML_SLICE1_OPERATOR_DESC sliceDesc = {};
sliceDesc.InputTensor = inputDescs.data();
sliceDesc.OutputTensor = outputDescs.data();
sliceDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_offsets.size());
sliceDesc.Offsets = m_offsets.data();
sliceDesc.Sizes = m_sizes.data();
sliceDesc.Strides = m_strides.data();
sliceDesc.InputWindowOffsets = m_offsets.data();
sliceDesc.InputWindowSizes = m_sizes.data();
sliceDesc.InputWindowStrides = m_strides.data();
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE, &sliceDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
@ -73,7 +60,7 @@ public:
// Support-query callback for Slice: reports through *isSupported whether the
// node can be handled, judged solely by the node's input count.
void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
// NOTE(review): this result is overwritten by the assignment below, so the
// effective limit is 5 inputs.
*isSupported = (context->GetInputCount() <= 4);
*isSupported = (context->GetInputCount() <= 5);
}
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>);

View file

@ -233,6 +233,7 @@ const static char* const typeNameListT1T2[2] = { "T1", "T2" };
const static char* const typeNameListConstantOfShape[2] = { "T1", "T2" };
const static char* const typeNameListScatterGather[2] = { "T", "Tind" };
const static char* const typeNameListScatterGatherND[1] = { "T" }; // Tind is curiously missing, only allowing 64-bit.
const static char* const typeNameListSlice10[2] = { "T", "Tind" };
const static char* const typeNameListQuantize[2] = { "T1", "T2" };
const static char* const typeNameListWhere[2] = { "B", "T" };
const static char* const typeNameListOneHot[3] = { "T1", "T2", "T3" };
@ -249,6 +250,7 @@ const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedT
const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::NumericDefault };
const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
@ -332,9 +334,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)},
#if 0 // TODO:DwayneR
{REG_INFO_VER( 11, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, // Adds negative axes.
#endif
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)},
{REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
#if 0 // TODO:NickFe Pads and Value are inputs. https://microsoft.visualstudio.com/OS/_workitems/edit/24674281, https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
{REG_INFO( 11, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},

View file

@ -297,7 +297,7 @@ void TensorDesc::ForceUnsignedDataType()
m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8;
break;
// Nothing to do if already unsigned
// Nothing to do if already unsigned
case DML_TENSOR_DATA_TYPE_UINT32:
case DML_TENSOR_DATA_TYPE_UINT16:
case DML_TENSOR_DATA_TYPE_UINT8:

View file

@ -563,6 +563,20 @@ public:
return m_impl->GetOutputCount();
}
// Returns true if an input to the operator is valid.
// This returns false for optional omitted inputs and invalid indices.
// Thin forwarder: the answer comes from the wrapped implementation (m_impl).
bool IsInputValid(uint32_t inputIndex) const noexcept
{
return m_impl->IsInputValid(inputIndex);
}
// Returns true if an output to the operator is valid.
// This returns false for optional omitted outputs and invalid indices.
// Thin forwarder: the answer comes from the wrapped implementation (m_impl).
// NOTE(review): the parameter is named inputIndex but is used as an output index.
bool IsOutputValid(uint32_t inputIndex) const noexcept
{
return m_impl->IsOutputValid(inputIndex);
}
MLOperatorEdgeDescription GetInputEdgeDescription(uint32_t inputIndex) const
{
MLOperatorEdgeDescription ret;

View file

@ -32,6 +32,54 @@ namespace OperatorHelper
}
}
// Reads every element of a CPU-resident int32 or int64 tensor into 'result'
// as int32 values. Any previous contents of 'result' are discarded.
//
// tensor - source tensor; must report IsCpuData() and have data type Int32 or
//          Int64, otherwise the ML_* argument checks below fire.
// result - receives one value per tensor element, in storage order.
//
// Int64 values that do not fit in int32 are saturated to INT32_MIN/INT32_MAX
// rather than bit-truncated. Models commonly pass INT64_MAX / INT64_MIN as
// "to the end" / "from the start" sentinels; a wrapping cast would turn those
// into -1 / 0 and silently change their meaning.
void ReadCpuLocalTensorIntoInt32(
    const MLOperatorTensor& tensor,
    std::vector<int32_t>& result
    )
{
    result.clear();
    ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");

    const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
    const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);

    switch (tensor.GetTensorDataType())
    {
    case MLOperatorTensorDataType::Int32:
        {
            const int32_t* data = tensor.GetData<int32_t>();
            result.assign(data, data + elementCount);
        }
        break;

    case MLOperatorTensorDataType::Int64:
        {
            const int64_t* data = tensor.GetData<int64_t>();
            result.reserve(elementCount);

            for (auto d : gsl::make_span(data, data + elementCount))
            {
                // Clamp first so the narrowing cast can never wrap.
                result.push_back(gsl::narrow_cast<int32_t>(std::clamp<int64_t>(d, INT32_MIN, INT32_MAX)));
            }
        }
        break;

    default:
        ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64.");
        break;
    }
}
// Converts 64-bit dimensions into 'outputDimensions', replacing whatever the
// vector held before. Each value is clamped to the int32 range and then
// narrowed to uint32_t.
void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
{
    outputDimensions.clear();
    outputDimensions.reserve(inputDimensions.size());

    for (const int64_t dimension : inputDimensions)
    {
        const int64_t clampedDimension = std::clamp<int64_t>(dimension, INT32_MIN, INT32_MAX);
        outputDimensions.push_back(gsl::narrow_cast<uint32_t>(clampedDimension));
    }
}
int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p)
{
switch (tensorDataType)
@ -495,7 +543,7 @@ namespace OperatorHelper
return edgeShapes;
}
std::vector<EdgeShapes> SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
std::vector<EdgeShapes> SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { m_outputDimensions };
}
@ -1157,10 +1205,7 @@ namespace OperatorHelper
// First element of shape tensor is how many dims to expand to.
std::vector<uint32_t> desiredTensorShape;
for (int64_t dim : gsl::make_span(shapeData, dimCount))
{
desiredTensorShape.push_back(gsl::narrow_cast<uint32_t>(dim));
}
DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape);
// Determine the broadcasted input shape.
outputDimensions = OperatorHelper::BroadcastTensorShape(actualInputTensorShape, desiredTensorShape);
@ -1185,10 +1230,7 @@ namespace OperatorHelper
// First element of shape tensor is how many dims to expand to.
std::vector<uint32_t> desiredTensorShape;
for (int64_t dim : gsl::make_span(shapeData, dimCount))
{
desiredTensorShape.push_back(gsl::narrow_cast<uint32_t>(dim));
}
DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape);
return { std::move(EdgeShapes(desiredTensorShape)) };
}

View file

@ -75,12 +75,32 @@ void RemoveValuesByIndex(gsl::span<const uint32_t> indices, bool keepOneValue, /
values.resize(newValuesCount);
}
// Left-pads 'values' with copies of 'fillValue' until it holds at least
// 'minimumElementCount' elements. Existing elements keep their relative order
// and end up at the tail; a vector that is already long enough is unchanged.
//
// e.g.
//     values              = [6,7]
//     minimumElementCount = 4
//     fillValue           = 1
//     result              = [1,1,6,7]
template <typename T>
void FillWithLeadingValues(/*inout*/ std::vector<T>& values, uint32_t minimumElementCount, T fillValue)
{
    const size_t currentCount = values.size();
    const size_t requiredCount = minimumElementCount;

    if (requiredCount > currentCount)
    {
        // Prepend only the shortfall.
        values.insert(values.begin(), requiredCount - currentCount, fillValue);
    }
}
int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p);
double ReadAsFloat64(MLOperatorTensorDataType tensorDataType, const void* p);
void ReadScalarTensorData(const MLOperatorTensor& tensor, /*out*/ void* data, size_t dataByteSize);
int64_t ReadScalarTensorAsInt64(const MLOperatorTensor& tensor);
double ReadScalarTensorAsFloat64(const MLOperatorTensor& tensor);
void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector<int32_t>& result);
class EdgeShapes {
public:
EdgeShapes() = default;
@ -530,57 +550,25 @@ class SplitHelper {
class SliceHelperBase
{
public:
template<typename Info_t, typename Index_t>
template<typename Info_t>
void ReadIndexTensors(
const Info_t& operatorInfo,
std::vector<int32_t>& starts,
std::vector<int32_t>& ends,
std::vector<int32_t>& axes,
std::vector<int32_t>& steps
)
/*out*/ std::vector<int32_t>& starts,
/*out*/ std::vector<int32_t>& ends,
/*out*/ std::vector<int32_t>& axes,
/*out*/ std::vector<int32_t>& steps
)
{
// Get starts, ends, optional axes and optional steps from constant inputs.
MLOperatorTensor startsTensor = operatorInfo.GetConstantInputTensor(1);
const std::vector<uint32_t>& startsTensorDimensions = startsTensor.GetShape();
size_t dimCount = startsTensorDimensions[0];
const Index_t* startsData = startsTensor.GetData<Index_t>();
for (size_t i = 0; i < dimCount; ++i)
// Get starts, ends, optional axes, and optional steps from constant inputs.
ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(1), /*out*/ starts);
ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(2), /*out*/ ends);
if (operatorInfo.IsInputValid(3))
{
starts.push_back(gsl::narrow_cast<int32_t>(startsData[i]));
ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(3), /*out*/ axes);
}
MLOperatorTensor endsTensor = operatorInfo.GetConstantInputTensor(2);
const std::vector<uint32_t>& endsTensorDimensions = endsTensor.GetShape();
dimCount = endsTensorDimensions[0];
const Index_t* endsData = endsTensor.GetData<Index_t>();
for (size_t i = 0; i < dimCount; ++i)
if (operatorInfo.IsInputValid(4))
{
ends.push_back(gsl::narrow_cast<int32_t>(endsData[i]));
}
uint32_t inputCount = operatorInfo.GetInputCount();
if (inputCount > 3)
{
MLOperatorTensor axesTensor = operatorInfo.GetConstantInputTensor(3);
const std::vector<uint32_t>& axesTensorDimensions = axesTensor.GetShape();
dimCount = axesTensorDimensions[0];
const Index_t* axesData = axesTensor.GetData<Index_t>();
for (size_t i = 0; i < dimCount; ++i)
{
axes.push_back(gsl::narrow_cast<int32_t>(axesData[i]));
}
}
if (inputCount > 4)
{
MLOperatorTensor stepsTensor = operatorInfo.GetConstantInputTensor(4);
const std::vector<uint32_t>& stepsTensorDimensions = stepsTensor.GetShape();
dimCount = stepsTensorDimensions[0];
const Index_t* stepsData = stepsTensor.GetData<Index_t>();
for (size_t i = 0; i < dimCount; ++i)
{
steps.push_back(gsl::narrow_cast<int32_t>(stepsData[i]));
}
ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(4), /*out*/ steps);
}
}
@ -595,35 +583,30 @@ public:
std::vector<int32_t> ends;
std::vector<int32_t> axes;
std::vector<int32_t> steps;
if (opsetVersion == 7)
{
// Get starts, ends and axes from attributes
// Read starts, ends, and axes from attributes.
starts = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Starts);
ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends);
axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes);
}
else if (opsetVersion == 10)
else if (opsetVersion == 10 || opsetVersion == 11)
{
if (operatorInfo.GetConstantInputTensor(1).GetTensorDataType() == MLOperatorTensorDataType::Int32)
{
ReadIndexTensors<Info_t, int32_t>(operatorInfo, starts, ends, axes, steps);
}
else
{
THROW_HR_IF(E_INVALIDARG, operatorInfo.GetConstantInputTensor(1).GetTensorDataType() != MLOperatorTensorDataType::Int64);
ReadIndexTensors<Info_t, int64_t>(operatorInfo, starts, ends, axes, steps);
}
// Read starts, ends, and axes from tensors.
ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps);
}
const uint32_t dimCount = gsl::narrow_cast<int32_t>(inputDimensions.size());
HandleNegativeAxes(/*inout*/ axes, dimCount);
const uint32_t inputDimensionCount = gsl::narrow_cast<int32_t>(inputDimensions.size());
HandleNegativeAxes(/*inout*/ axes, inputDimensionCount);
ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size.");
ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty.");
m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end());
m_offsets.resize(m_outputDimensions.size());
m_sizes.resize(m_outputDimensions.size());
m_strides = std::move(steps);
m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2.
// Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count.
@ -632,18 +615,28 @@ public:
// Clamp selected dimensions to given 'starts' and 'ends'.
for (int i = 0, ci = gsl::narrow_cast<int>(starts.size()); i < ci; ++i)
{
int dimIndex = i;
if (!axes.empty())
{
dimIndex = axes[i];
}
int dimIndex = axes.empty() ? i : axes[i];
int stride = m_strides[i];
ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0.");
// Positive values are offsets from 0.
// Negative values are offsets from the dimension's size.
// Negative values are offsets from back of the dimension's size.
// INT_MIN is a special value in ONNX which means to treat it as the smallest
// possible value, rather than the usual reversed from-the-back semantics.
int dim = gsl::narrow_cast<int>(inputDimensions[dimIndex]);
int start = (starts[i] < 0) ? (starts[i] + dim) : starts[i];
int end = (ends[i] < 0) ? (ends[i] + dim) : ends[i];
// Each bound is tested against its own INT_MIN sentinel before being rebased
// from the back of the dimension ('end' must not depend on starts[i]).
int start = (starts[i] < 0 && starts[i] > INT_MIN) ? (starts[i] + dim) : starts[i];
int end = (ends[i] < 0 && ends[i] > INT_MIN) ? (ends[i] + dim) : ends[i];
// For negative strides, the ONNX start and end values are off-by-one.
// So fix them such that the start value remains the minimum extent
// of the slice window, and end remains the maximum exclusive extent.
if (stride < 0)
{
std::swap(start, end);
start += (start < INT_MAX) ? 1 : 0; // Avoid overflow wrap.
end += (end < INT_MAX) ? 1 : 0;
}
// Clamp the dimensions to the slice extents.
// Clamp negative numbers to 0, per case test_slice_start_out_of_bounds.
@ -651,7 +644,11 @@ public:
end = std::min(end, dim);
int size = std::max(end - start, 0);
m_outputDimensions[dimIndex] = size;
// Set the input window offsets/sizes, and compute output size based on input
// window size (rounding up).
// e.g. a window size 13 and step 3 yields 5 output elements.
int absoluteStride = abs(stride);
m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0);
m_offsets[dimIndex] = start;
m_sizes[dimIndex] = gsl::narrow_cast<uint32_t>(size);
}
@ -671,7 +668,7 @@ public:
std::vector<DimensionType> m_outputDimensions;
std::vector<uint32_t> m_offsets;
std::vector<uint32_t> m_sizes;
std::vector<uint32_t> m_strides;
std::vector<int32_t> m_strides;
};
class SliceHelper : public SliceHelperBase
@ -1282,7 +1279,7 @@ using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
using ShapeInferenceHelper_Slice7 = SliceHelper;
using ShapeInferenceHelper_Slice10 = Slice10Helper;
using ShapeInferenceHelper_Slice11 = Slice10Helper; // 11 and 10 are identical.
using ShapeInferenceHelper_Slice11 = Slice10Helper; // 11 and 10 are identical - no functional change.
using ShapeInferenceHelper_Pad = PaddingHelper;
using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper;
using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;