diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 077f7127bf..98d21d85f8 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -937,9 +937,9 @@ Do not modify directly.* |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| -|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| -|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| -|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| +|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|7+|**T** = tensor(float), tensor(float16)| |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -1052,8 +1052,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |||6+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| -|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| -|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| +|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| +|||9+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)| |OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T3** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 13919420ee..60d706b6b4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -534,6 +534,7 @@ constexpr static std::array supportedTypeListSize = constexpr static std::array supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8}; constexpr static std::array supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32}; constexpr static std::array supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListQLinearMatMul = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -733,9 +734,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 14, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 13, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 14, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, + {REG_INFO( 14, Div, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 13, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, @@ -919,8 +920,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, - {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)}, - {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListFloat16to32Ints8to32, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, // DmlFused operators {REG_INFO_MSDML(1, DmlFusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},