diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 4847bc4d8a..9a9ce5d672 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -25,7 +25,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 110; + static constexpr auto ValueCount = 119; static constexpr size_t ActivationFunctionCount = 19; }; @@ -96,6 +96,12 @@ struct EnumTraits static constexpr auto ValueCount = 3; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + template <> struct EnumTraits { @@ -706,6 +712,60 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING2; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; +}; + template <> struct OperatorDescTraits { @@ -1366,6 +1426,60 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND> using DescType = DML_SCATTER_ND_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING2> +{ + using DescType = DML_MAX_POOLING2_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE1> +{ + using DescType = DML_SLICE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K1> +{ + using DescType = DML_TOP_K1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE1> +{ + using DescType = DML_DEPTH_TO_SPACE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH1> +{ + using DescType = DML_SPACE_TO_DEPTH1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1> +{ + using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY> +{ + using DescType = DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION> +{ + using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR> +{ + using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -1673,6 +1787,24 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_SCATTER_ND: return std::invoke(std::forward(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING2: + return std::invoke(std::forward(visitor), DML_MAX_POOLING2_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SLICE1: + return std::invoke(std::forward(visitor), DML_SLICE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_TOP_K1: + return std::invoke(std::forward(visitor), DML_TOP_K1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DEPTH_TO_SPACE1: + return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SPACE_TO_DEPTH1: + return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: + return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARDMAX: @@ -1813,6 +1945,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND"; + case DML_OPERATOR_MAX_POOLING2: return "DML_OPERATOR_MAX_POOLING2"; + case DML_OPERATOR_SLICE1: return "DML_OPERATOR_SLICE1"; + case DML_OPERATOR_TOP_K1: return "DML_OPERATOR_TOP_K1"; + case DML_OPERATOR_DEPTH_TO_SPACE1: return "DML_OPERATOR_DEPTH_TO_SPACE1"; + case DML_OPERATOR_SPACE_TO_DEPTH1: return "DML_OPERATOR_SPACE_TO_DEPTH1"; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1"; + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY"; + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index ff2a5e0fbd..a80653a78b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1441,6 +1441,172 @@ constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA { DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndicesTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING2_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_POOLING2", + DML_OPERATOR_MAX_POOLING2, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SLICE1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SLICE1_OPERATOR_SCHEMA { + "DML_OPERATOR_SLICE1", + DML_OPERATOR_SLICE1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SLICE1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_TOP_K1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputValueTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndexTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "K", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_TOP_K1_OPERATOR_SCHEMA { + "DML_OPERATOR_TOP_K1", + DML_OPERATOR_TOP_K1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_TOP_K1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA { + "DML_OPERATOR_DEPTH_TO_SPACE1", + DML_OPERATOR_DEPTH_TO_SPACE1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA { + "DML_OPERATOR_SPACE_TO_DEPTH1", + DML_OPERATOR_SPACE_TO_DEPTH1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "NormalizeVariance", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA { + "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", + DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", + DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS[16] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "GroupCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", + DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 16, + DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", + DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 04411a94fe..6d33af8a9c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -852,6 +852,118 @@ inline std::vector GetFields(const DML_SCATTER_ND_OPERATOR_DESC& OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), }; } +inline std::vector GetFields(const DML_MAX_POOLING2_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndicesTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_SLICE1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputWindowOffsets), desc.DimensionCount)), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.InputWindowSizes), desc.DimensionCount)), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InputWindowStrides), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_TOP_K1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputValueTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndexTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.K))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.AxisDirection))), + }; +} +inline std::vector GetFields(const DML_DEPTH_TO_SPACE1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Order))), + }; +} +inline std::vector GetFields(const DML_SPACE_TO_DEPTH1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Order))), + }; +} +inline std::vector GetFields(const DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.NormalizeVariance))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.FilterTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.FilterScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.FilterZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.GroupCount))), + }; +} +inline std::vector GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1097,6 +1209,15 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA; case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_POOLING2: return DML_MAX_POOLING2_OPERATOR_SCHEMA; + case DML_OPERATOR_SLICE1: return DML_SLICE1_OPERATOR_SCHEMA; + case DML_OPERATOR_TOP_K1: return DML_TOP_K1_OPERATOR_SCHEMA; + case DML_OPERATOR_DEPTH_TO_SPACE1: return DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA; + case DML_OPERATOR_SPACE_TO_DEPTH1: return DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; @@ -1485,6 +1606,42 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_SCATTER_ND_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_POOLING2: + return AbstractOperatorDesc( + &DML_MAX_POOLING2_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SLICE1: + return AbstractOperatorDesc( + &DML_SLICE1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_TOP_K1: + return AbstractOperatorDesc( + &DML_TOP_K1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DEPTH_TO_SPACE1: + return AbstractOperatorDesc( + &DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SPACE_TO_DEPTH1: + return AbstractOperatorDesc( + &DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: + return AbstractOperatorDesc( + &DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 21c8c46af7..25f0dd26c6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -16,7 +16,7 @@ using ApiAttributeVariant = std::variant< const DML_SCALE_BIAS*, DML_SIZE_2D, DML_SCALAR_UNION ->; + >; namespace OperatorFieldTypes { @@ -51,7 +51,7 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::ScaleBias, OperatorFieldTypes::Size2D, OperatorFieldTypes::ScalarUnion ->; + >; class OperatorField { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 0c4e3fb9ea..47cbcae00f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1732,7 +1732,8 @@ void InferAndVerifyOutputSizes( for (uint32_t output_dim = 0; output_dim < outputShapes.GetShape(outputIndex).size(); ++output_dim) { if (shape.dim(output_dim).has_dim_value()) { int64_t expected_size = shape.dim(output_dim).dim_value(); - ML_CHECK_BOOL(expected_size == outputShapes.GetShape(outputIndex)[output_dim]); + int64_t actual_size = outputShapes.GetShape(outputIndex)[output_dim]; + ML_CHECK_BOOL(expected_size == actual_size); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index bbeb9d9a87..9837e79269 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -13,49 +13,36 @@ public: : DmlOperator(kernelInfo), SliceHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion) { - uint32_t minInputCount = (opsetVersion < 10) ? 1 : 3; - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= minInputCount); + const uint32_t inputCount = kernelInfo.GetInputCount(); + ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1) + || (opsetVersion >= 10 && opsetVersion <= 11 && inputCount >= 3 && inputCount <= 5)); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - // TODO (23108599): Slice V10 introduces an optional "Steps" input which the kernel does not yet support. - THROW_HR_IF(E_NOTIMPL, kernelInfo.GetInputCount() > 4); - - std::vector> kernelInputIndices = { 0 }; + std::vector> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor. DmlOperator::Initialize(kernelInfo, kernelInputIndices); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_offsets.size())); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_sizes.size())); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_strides.size())); + const uint32_t inputTensorRank = m_inputTensorDescs[0].GetDimensionCount(); + assert(inputTensorRank >= gsl::narrow_cast(m_offsets.size())); + assert(inputTensorRank >= gsl::narrow_cast(m_sizes.size())); + assert(inputTensorRank >= gsl::narrow_cast(m_strides.size())); // Pad the parameters to respect DML's requirements - m_offsets.insert( - m_offsets.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_offsets.size()), - 0); - - m_sizes.insert( - m_sizes.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_sizes.size()), - 1); - - m_strides.insert( - m_strides.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_strides.size()), - 1); + FillWithLeadingValues(/*inout*/ m_offsets, inputTensorRank, 0u); + FillWithLeadingValues(/*inout*/ m_sizes, inputTensorRank, 1u); + FillWithLeadingValues(/*inout*/ m_strides, inputTensorRank, 1); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_SLICE_OPERATOR_DESC sliceDesc = {}; + DML_SLICE1_OPERATOR_DESC sliceDesc = {}; sliceDesc.InputTensor = inputDescs.data(); sliceDesc.OutputTensor = outputDescs.data(); sliceDesc.DimensionCount = gsl::narrow_cast(m_offsets.size()); - sliceDesc.Offsets = m_offsets.data(); - sliceDesc.Sizes = m_sizes.data(); - sliceDesc.Strides = m_strides.data(); + sliceDesc.InputWindowOffsets = m_offsets.data(); + sliceDesc.InputWindowSizes = m_sizes.data(); + sliceDesc.InputWindowStrides = m_strides.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE, &sliceDesc }; - + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } }; @@ -73,7 +60,7 @@ public: void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) { - *isSupported = (context->GetInputCount() <= 4); + *isSupported = (context->GetInputCount() <= 5); } DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 92f300969f..315f347960 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -233,6 +233,7 @@ const static char* const typeNameListT1T2[2] = { "T1", "T2" }; const static char* const typeNameListConstantOfShape[2] = { "T1", "T2" }; const static char* const typeNameListScatterGather[2] = { "T", "Tind" }; const static char* const typeNameListScatterGatherND[1] = { "T" }; // Tind is curiously missing, only allowing 64-bit. +const static char* const typeNameListSlice10[2] = { "T", "Tind" }; const static char* const typeNameListQuantize[2] = { "T1", "T2" }; const static char* const typeNameListWhere[2] = { "B", "T" }; const static char* const typeNameListOneHot[3] = { "T1", "T2", "T3" }; @@ -249,6 +250,7 @@ const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedT const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::NumericDefault }; +const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; @@ -332,9 +334,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis. {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, {REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, -#if 0 // TODO:DwayneR - {REG_INFO_VER( 11, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, // Adds negative axes. -#endif + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, {REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, #if 0 // TODO:NickFe Pads and Value are inputs. https://microsoft.visualstudio.com/OS/_workitems/edit/24674281, https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11 {REG_INFO( 11, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 4bf97d889c..b4a286e08e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -297,7 +297,7 @@ void TensorDesc::ForceUnsignedDataType() m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8; break; - // Nothing to do if already unsigned + // Nothing to do if already unsigned case DML_TENSOR_DATA_TYPE_UINT32: case DML_TENSOR_DATA_TYPE_UINT16: case DML_TENSOR_DATA_TYPE_UINT8: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 7fee23b8d0..a6615bd815 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -563,6 +563,20 @@ public: return m_impl->GetOutputCount(); } + // Returns true if an input to the operator is valid. + // This returns false for optional omitted inputs and invalid indices. + bool IsInputValid(uint32_t inputIndex) const noexcept + { + return m_impl->IsInputValid(inputIndex); + } + + // Returns true if an output to the operator is valid. + // This returns false for optional omitted inputs and invalid indices. + bool IsOutputValid(uint32_t inputIndex) const noexcept + { + return m_impl->IsOutputValid(inputIndex); + } + MLOperatorEdgeDescription GetInputEdgeDescription(uint32_t inputIndex) const { MLOperatorEdgeDescription ret; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index dbaf941203..d407a7c011 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -32,6 +32,54 @@ namespace OperatorHelper } } + void ReadCpuLocalTensorIntoInt32( + const MLOperatorTensor& tensor, + std::vector& result + ) + { + result.clear(); + ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); + + const std::vector& tensorDimensions = tensor.GetShape(); + const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); + + switch (tensor.GetTensorDataType()) + { + case MLOperatorTensorDataType::Int32: + { + const int32_t* data = tensor.GetData(); + result.assign(data, data + elementCount); + } + break; + + case MLOperatorTensorDataType::Int64: + { + const int64_t* data = tensor.GetData(); + result.reserve(elementCount); + for (auto d : gsl::make_span(data, data + elementCount)) + { + result.push_back(gsl::narrow_cast(d)); + } + } + break; + + default: + ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); + break; + } + } + + void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) + { + outputDimensions.reserve(inputDimensions.size()); + outputDimensions.clear(); + + for (int64_t dim : inputDimensions) + { + outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); + } + } + int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) { switch (tensorDataType) @@ -495,7 +543,7 @@ namespace OperatorHelper return edgeShapes; } - std::vector SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { return { m_outputDimensions }; } @@ -1157,10 +1205,7 @@ namespace OperatorHelper // First element of shape tensor is how many dims to expand to. std::vector desiredTensorShape; - for (int64_t dim : gsl::make_span(shapeData, dimCount)) - { - desiredTensorShape.push_back(gsl::narrow_cast(dim)); - } + DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape); // Determine the broadcasted input shape. outputDimensions = OperatorHelper::BroadcastTensorShape(actualInputTensorShape, desiredTensorShape); @@ -1185,10 +1230,7 @@ namespace OperatorHelper // First element of shape tensor is how many dims to expand to. std::vector desiredTensorShape; - for (int64_t dim : gsl::make_span(shapeData, dimCount)) - { - desiredTensorShape.push_back(gsl::narrow_cast(dim)); - } + DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape); return { std::move(EdgeShapes(desiredTensorShape)) }; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index d80b54731e..3ab5879515 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -75,12 +75,32 @@ void RemoveValuesByIndex(gsl::span indices, bool keepOneValue, / values.resize(newValuesCount); } +template +void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, T fillValue) +{ + // e.g. + // input = [6,7] + // elementCount = 4 + // fillValue = 1 + // output = [1,1,6,7] + + const size_t oldElementCount = values.size(); + const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); + const size_t fillCount = newElementCount - oldElementCount; + + values.resize(newElementCount); + std::copy_backward(values.begin(), values.begin() + oldElementCount, values.end()); + std::fill_n(values.data(), fillCount, fillValue); +} + int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p); double ReadAsFloat64(MLOperatorTensorDataType tensorDataType, const void* p); void ReadScalarTensorData(const MLOperatorTensor& tensor, /*out*/ void* data, size_t dataByteSize); int64_t ReadScalarTensorAsInt64(const MLOperatorTensor& tensor); double ReadScalarTensorAsFloat64(const MLOperatorTensor& tensor); +void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector& result); + class EdgeShapes { public: EdgeShapes() = default; @@ -530,57 +550,25 @@ class SplitHelper { class SliceHelperBase { public: - template + template void ReadIndexTensors( const Info_t& operatorInfo, - std::vector& starts, - std::vector& ends, - std::vector& axes, - std::vector& steps - ) + /*out*/ std::vector& starts, + /*out*/ std::vector& ends, + /*out*/ std::vector& axes, + /*out*/ std::vector& steps + ) { - // Get starts, ends, optional axes and optional steps from constant inputs. - MLOperatorTensor startsTensor = operatorInfo.GetConstantInputTensor(1); - const std::vector& startsTensorDimensions = startsTensor.GetShape(); - size_t dimCount = startsTensorDimensions[0]; - const Index_t* startsData = startsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) + // Get starts, ends, optional axes, and optional steps from constant inputs. + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(1), /*out*/ starts); + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(2), /*out*/ ends); + if (operatorInfo.IsInputValid(3)) { - starts.push_back(gsl::narrow_cast(startsData[i])); + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(3), /*out*/ axes); } - - MLOperatorTensor endsTensor = operatorInfo.GetConstantInputTensor(2); - const std::vector& endsTensorDimensions = endsTensor.GetShape(); - dimCount = endsTensorDimensions[0]; - const Index_t* endsData = endsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) + if (operatorInfo.IsInputValid(4)) { - ends.push_back(gsl::narrow_cast(endsData[i])); - } - - uint32_t inputCount = operatorInfo.GetInputCount(); - if (inputCount > 3) - { - MLOperatorTensor axesTensor = operatorInfo.GetConstantInputTensor(3); - const std::vector& axesTensorDimensions = axesTensor.GetShape(); - dimCount = axesTensorDimensions[0]; - const Index_t* axesData = axesTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) - { - axes.push_back(gsl::narrow_cast(axesData[i])); - } - } - - if (inputCount > 4) - { - MLOperatorTensor stepsTensor = operatorInfo.GetConstantInputTensor(4); - const std::vector& stepsTensorDimensions = stepsTensor.GetShape(); - dimCount = stepsTensorDimensions[0]; - const Index_t* stepsData = stepsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) - { - steps.push_back(gsl::narrow_cast(stepsData[i])); - } + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(4), /*out*/ steps); } } @@ -595,35 +583,30 @@ public: std::vector ends; std::vector axes; std::vector steps; + if (opsetVersion == 7) { - // Get starts, ends and axes from attributes + // Read starts, ends, and axes from attributes. starts = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Starts); ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends); axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes); } - else if (opsetVersion == 10) + else if (opsetVersion == 10 || opsetVersion == 11) { - if (operatorInfo.GetConstantInputTensor(1).GetTensorDataType() == MLOperatorTensorDataType::Int32) - { - ReadIndexTensors(operatorInfo, starts, ends, axes, steps); - } - else - { - THROW_HR_IF(E_INVALIDARG, operatorInfo.GetConstantInputTensor(1).GetTensorDataType() != MLOperatorTensorDataType::Int64); - ReadIndexTensors(operatorInfo, starts, ends, axes, steps); - } + // Read starts, ends, and axes from tensors. + ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps); } - const uint32_t dimCount = gsl::narrow_cast(inputDimensions.size()); - HandleNegativeAxes(/*inout*/ axes, dimCount); - + const uint32_t inputDimensionCount = gsl::narrow_cast(inputDimensions.size()); + HandleNegativeAxes(/*inout*/ axes, inputDimensionCount); + ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size."); ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty."); m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end()); m_offsets.resize(m_outputDimensions.size()); m_sizes.resize(m_outputDimensions.size()); + m_strides = std::move(steps); m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2. // Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count. @@ -632,18 +615,28 @@ public: // Clamp selected dimensions to given 'starts' and 'ends'. for (int i = 0, ci = gsl::narrow_cast(starts.size()); i < ci; ++i) { - int dimIndex = i; - if (!axes.empty()) - { - dimIndex = axes[i]; - } + int dimIndex = axes.empty() ? i : axes[i]; + int stride = m_strides[i]; ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions."); + ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0."); // Positive values are offsets from 0. - // Negative values are offsets from the dimension's size. + // Negative values are offsets from back of the dimension's size. + // INT_MIN is a special value in ONNX which means to treat it as the smallest + // possible value, rather than the usual reversed from-the-back semantics. int dim = gsl::narrow_cast(inputDimensions[dimIndex]); - int start = (starts[i] < 0) ? (starts[i] + dim) : starts[i]; - int end = (ends[i] < 0) ? (ends[i] + dim) : ends[i]; + int start = (starts[i] < 0 && starts[i] > INT_MIN) ? (starts[i] + dim) : starts[i]; + int end = (ends[i] < 0 && starts[i] > INT_MIN) ? (ends[i] + dim) : ends[i]; + + // For negative strides, the ONNX start and end values are off-by-one. + // So fix them such that the start value remains the minimum extent + // of the slice window, and end remains the maximum exclusive extent. + if (stride < 0) + { + std::swap(start, end); + start += (start < INT_MAX) ? 1 : 0; // Avoid overflow wrap. + end += (end < INT_MAX) ? 1 : 0; + } // Clamp the dimensions to the slice extents. // Clamp negative numbers to 0, per case test_slice_start_out_of_bounds. @@ -651,7 +644,11 @@ public: end = std::min(end, dim); int size = std::max(end - start, 0); - m_outputDimensions[dimIndex] = size; + // Set the input window offsets/sizes, and compute output size based on input + // window size (rounding up). + // e.g. a window size 13 and step 3 yields 5 output elements. + int absoluteStride = abs(stride); + m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0); m_offsets[dimIndex] = start; m_sizes[dimIndex] = gsl::narrow_cast(size); } @@ -671,7 +668,7 @@ public: std::vector m_outputDimensions; std::vector m_offsets; std::vector m_sizes; - std::vector m_strides; + std::vector m_strides; }; class SliceHelper : public SliceHelperBase @@ -1282,7 +1279,7 @@ using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; using ShapeInferenceHelper_Slice7 = SliceHelper; using ShapeInferenceHelper_Slice10 = Slice10Helper; -using ShapeInferenceHelper_Slice11 = Slice10Helper; // 11 and 10 are identical. +using ShapeInferenceHelper_Slice11 = Slice10Helper; // 11 and 10 are identical - no functional change. using ShapeInferenceHelper_Pad = PaddingHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;