mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Merged PR 4986854: Opset 12: Clip, Max, Min, MaxPool, ReduceMax, ReduceMin
Add Clip-12, Max-12 (Adds int support). Add MaxPool-12, ReduceMax-12, ReduceMin-12 (int8 support) windowsai pr https://microsoft.visualstudio.com/WindowsAI/_git/WindowsAI/pullrequest/4983894
This commit is contained in:
parent
c67dd693c2
commit
f74e55cfc9
4 changed files with 30 additions and 4 deletions
|
|
@ -462,6 +462,9 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Same operator signature as 11. Only difference is new type support
|
||||
using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11;
|
||||
|
||||
class DmlOperatorElementwisePow : public DmlOperator
|
||||
{
|
||||
public:
|
||||
|
|
@ -718,6 +721,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
|
|||
// Operators with extra attributes:
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
|
|
|
|||
|
|
@ -34,13 +34,15 @@ enum class SupportedTensorDataTypes : uint32_t
|
|||
UInt64 = 1<<13,
|
||||
Complex64 = 1<<14,
|
||||
Complex128 = 1<<15,
|
||||
Int8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Int32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
|
||||
NumericDefault = Int8to32|Float16to32,
|
||||
NumericDefault = Ints8to32|Float16to32,
|
||||
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
||||
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
|
||||
Ints8Bit = UInt8|Int8,
|
||||
Ints16Bit = UInt16|Int16,
|
||||
Ints32Bit = UInt32|Int32,
|
||||
All = static_cast<uint32_t>(-1),
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
|
||||
|
|
@ -121,6 +123,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Ceil);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Floor);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Clip12);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Greater);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Less);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Equal);
|
||||
|
|
@ -253,8 +256,10 @@ constexpr static std::array<const char*, 1> typeNameListEyeLike = { "T2" };
|
|||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
|
||||
|
|
@ -340,7 +345,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
|
||||
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
|
||||
|
||||
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -407,6 +413,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
|
||||
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -417,8 +424,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
|
||||
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -457,8 +466,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
|
||||
|
|
|
|||
|
|
@ -1382,6 +1382,7 @@ using ShapeInferenceHelper_Ceil = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Equal = GetBroadcastedOutputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -245,6 +245,16 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Unsqueeze = 11;
|
||||
} // namespace OnnxOperatorSet11
|
||||
|
||||
namespace OnnxOperatorSet12
|
||||
{
|
||||
static const int sc_sinceVer_Clip = 12;
|
||||
static const int sc_sinceVer_Min = 12;
|
||||
static const int sc_sinceVer_Max = 12;
|
||||
static const int sc_sinceVer_MaxPool = 12;
|
||||
static const int sc_sinceVer_ReduceMax = 12;
|
||||
static const int sc_sinceVer_ReduceMin = 12;
|
||||
}
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
{
|
||||
static const int sc_sinceVer_FusedConv = 1;
|
||||
|
|
|
|||
Loading…
Reference in a new issue