From f74e55cfc984aead35c995b70afd7122d877bb21 Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Mon, 3 Aug 2020 19:35:34 +0000 Subject: [PATCH] Merged PR 4986854: Opset 12: Clip, Max, Min, MaxPool, ReduceMax, ReduceMin Add Clip-12, Max-12 (Adds int support). Add MaxPool-12, ReduceMax-12, ReduceMin-12 (int8 support) windowsai pr https://microsoft.visualstudio.com/WindowsAI/_git/WindowsAI/pullrequest/4983894 --- .../src/Operators/DmlOperatorElementWise.cpp | 4 ++++ .../src/Operators/OperatorRegistration.cpp | 19 +++++++++++++++---- .../dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../OperatorRegistration.h | 10 ++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 78a8d9e417..987a04f93f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -462,6 +462,9 @@ public: } }; +// Same operator signature as 11. Only difference is new type support +using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11; + class DmlOperatorElementwisePow : public DmlOperator { public: @@ -718,6 +721,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean); // Operators with extra attributes: DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7); DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11); +DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12); DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow); DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear); DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index f22d097683..6bad1a5df5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -34,13 +34,15 @@ enum class SupportedTensorDataTypes : uint32_t UInt64 = 1<<13, Complex64 = 1<<14, Complex128 = 1<<15, - Int8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, + Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, Int32to64 = UInt32|Int32|UInt64|Int64, Float16to32 = Float16|Float32, // Float64 is not supported by DirectML. - NumericDefault = Int8to32|Float16to32, + NumericDefault = Ints8to32|Float16to32, Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, Ints8Bit = UInt8|Int8, + Ints16Bit = UInt16|Int16, + Ints32Bit = UInt32|Int32, All = static_cast(-1), }; DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes); @@ -121,6 +123,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Ceil); DML_OP_EXTERN_CREATION_FUNCTION(Floor); DML_OP_EXTERN_CREATION_FUNCTION(Clip7); DML_OP_EXTERN_CREATION_FUNCTION(Clip11); +DML_OP_EXTERN_CREATION_FUNCTION(Clip12); DML_OP_EXTERN_CREATION_FUNCTION(Greater); DML_OP_EXTERN_CREATION_FUNCTION(Less); DML_OP_EXTERN_CREATION_FUNCTION(Equal); @@ -253,8 +256,10 @@ constexpr static std::array typeNameListEyeLike = { "T2" }; constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; -constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32}; +constexpr static std::array supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; +constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32}; constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; @@ -340,7 +345,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - + {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -407,6 +413,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, @@ -417,8 +424,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -457,8 +466,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index b06c81679e..abe0cbfa3c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1382,6 +1382,7 @@ using ShapeInferenceHelper_Ceil = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Equal = GetBroadcastedOutputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h index bb696c8fde..cfe23a2de7 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h @@ -245,6 +245,16 @@ namespace OperatorHelper static const int sc_sinceVer_Unsqueeze = 11; } // namespace OnnxOperatorSet11 + namespace OnnxOperatorSet12 + { + static const int sc_sinceVer_Clip = 12; + static const int sc_sinceVer_Min = 12; + static const int sc_sinceVer_Max = 12; + static const int sc_sinceVer_MaxPool = 12; + static const int sc_sinceVer_ReduceMax = 12; + static const int sc_sinceVer_ReduceMin = 12; + } + namespace MsftOperatorSet1 { static const int sc_sinceVer_FusedConv = 1;