From de6ebcbb5438593dcab47b79ab42fe951bb192ad Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 20 Aug 2024 23:44:58 -0700 Subject: [PATCH] [DML] Add int4 QDQ (#21592) --- docs/OperatorKernels.md | 37 ++-- .../DmlExecutionProvider/src/DmlCommon.cpp | 2 +- .../src/MLOperatorAuthorImpl.cpp | 28 ++- .../src/Operators/DmlOperatorElementWise.cpp | 162 ++++++++++++++++-- .../src/Operators/OperatorRegistration.cpp | 83 +++++---- .../MLOperatorAuthorHelper.h | 15 ++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 5 + .../OperatorAuthorHelper/OperatorVersions.h | 11 ++ .../cpu/tensor/quantize_linear_test.cc | 26 +++ 9 files changed, 307 insertions(+), 62 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 96173b5a4e..46d9e217bf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -958,7 +958,8 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -992,7 +993,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1107,8 +1109,8 @@ Do not modify directly.* |MeanVarianceNormalization|*in* X:**T**
*out* Y:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||9+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8+|**T** = tensor(float), tensor(float16)| @@ -1145,7 +1147,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| +|||19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| @@ -1199,7 +1202,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1230,10 +1234,11 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| @@ -1242,9 +1247,9 @@ Do not modify directly.* |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -1262,7 +1267,8 @@ Do not modify directly.* |||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1284,7 +1290,8 @@ Do not modify directly.* |Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|10+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp index 541254ffaf..aeb0bd9186 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp @@ -121,7 +121,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice) uint32_t deviceTypeMask = 0u; // Form the bitmask of all supported data types. - for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i) + for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT4; ++i) { DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast(i) }; DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 0a2a5bbcbe..26559b54bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -247,6 +247,8 @@ namespace Windows::AI::MachineLearning::Adapter } ML_TENSOR_TYPE_CASE(float); + ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); + ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); ML_TENSOR_TYPE_CASE(uint8_t); ML_TENSOR_TYPE_CASE(int8_t); ML_TENSOR_TYPE_CASE(uint16_t); @@ -293,6 +295,8 @@ namespace Windows::AI::MachineLearning::Adapter return onnxruntime::DataTypeImpl::GetTensorType(); ML_TENSOR_TYPE_CASE(float); + ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); + ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); ML_TENSOR_TYPE_CASE(uint8_t); ML_TENSOR_TYPE_CASE(int8_t); ML_TENSOR_TYPE_CASE(uint16_t); @@ -314,6 +318,8 @@ namespace Windows::AI::MachineLearning::Adapter return onnxruntime::DataTypeImpl::GetSequenceTensorType(); ML_SEQUENCE_TENSOR_TYPE_CASE(float); + ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); + ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base); ML_SEQUENCE_TENSOR_TYPE_CASE(uint8_t); ML_SEQUENCE_TENSOR_TYPE_CASE(int8_t); ML_SEQUENCE_TENSOR_TYPE_CASE(uint16_t); @@ -335,6 +341,8 @@ namespace Windows::AI::MachineLearning::Adapter return onnxruntime::DataTypeImpl::GetType(); ML_PRIMITIVE_TYPE_CASE(float); + ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base); + ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base); ML_PRIMITIVE_TYPE_CASE(uint8_t); ML_PRIMITIVE_TYPE_CASE(int8_t); ML_PRIMITIVE_TYPE_CASE(uint16_t); @@ -364,6 +372,12 @@ namespace Windows::AI::MachineLearning::Adapter case onnx::TensorProto_DataType_FLOAT: return MLOperatorTensorDataType::Float; + case onnx::TensorProto_DataType_UINT4: + return MLOperatorTensorDataType::UInt4; + + case onnx::TensorProto_DataType_INT4: + return MLOperatorTensorDataType::Int4; + case onnx::TensorProto_DataType_UINT8: return MLOperatorTensorDataType::UInt8; @@ -455,6 +469,12 @@ namespace Windows::AI::MachineLearning::Adapter case MLOperatorTensorDataType::Float: return "tensor(float)"; + case MLOperatorTensorDataType::UInt4: + return "tensor(uint4)"; + + case MLOperatorTensorDataType::Int4: + return "tensor(int4)"; + case MLOperatorTensorDataType::UInt8: return "tensor(uint8)"; @@ -509,6 +529,12 @@ namespace Windows::AI::MachineLearning::Adapter case MLOperatorTensorDataType::Float: return "seq(tensor(float))"; + case MLOperatorTensorDataType::UInt4: + return "seq(tensor(uint4))"; + + case MLOperatorTensorDataType::Int4: + return "seq(tensor(int4))"; + case MLOperatorTensorDataType::UInt8: return "seq(tensor(uint8))"; @@ -1518,7 +1544,7 @@ namespace Windows::AI::MachineLearning::Adapter AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc))); } - + // There can be operators (or kernels) which don't require any input. assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr); m_graphNodeCreateInfo->inputEdges.insert( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 16bb10f004..412207fd3c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -528,6 +528,7 @@ public: std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType(); bool hasZeroPointTensor = kernelInfo.IsInputValid(2); uint32_t axis = 0; @@ -601,6 +602,141 @@ public: } }; +template +class DmlOperatorQuantization21 : public DmlOperator +{ +public: + DmlOperatorQuantization21(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2 || kernelInfo.GetInputCount() == 3); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + Initialize(kernelInfo, std::nullopt, std::nullopt); + + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); + const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType(); + bool hasZeroPointTensor = kernelInfo.IsInputValid(2); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + std::vector quantizationTensors; + quantizationTensors.push_back(inputDescs[1]); + + const bool isSignedQuantization = + inputDataType == DML_TENSOR_DATA_TYPE_INT4 || + outputDataType == DML_TENSOR_DATA_TYPE_INT4 || + inputDataType == DML_TENSOR_DATA_TYPE_INT8 || + outputDataType == DML_TENSOR_DATA_TYPE_INT8; + + if (hasZeroPointTensor || isSignedQuantization) + { + if (hasZeroPointTensor) + { + quantizationTensors.push_back(inputDescs[2]); + } + + TOperatorDesc opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.QuantizationType = hasZeroPointTensor ? DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT : DML_QUANTIZATION_TYPE_SCALE; + opDesc.QuantizationTensorCount = static_cast(quantizationTensors.size()); + opDesc.QuantizationTensors = quantizationTensors.data(); + opDesc.OutputTensor = &outputDescs[0]; + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); + } + else + { + // For unsigned quantization, DML uses the midpoint of the datatype when zero point isn't provided (e.g. 8 for uint4 and 128 for uint8) + // since this is the most sane default in theory that represents the midpoint of the quantization. This is also the default used by various + // other frameworks and quantization tools. But because ONNX uses a default zero point of 0 no matter if it's signed or unsigned, we need + // to generate a constant zero point tensor with a value of 0 to override dml's default value when it isn't provided. + auto zeroPointDataType = std::is_same_v ? outputDataType : inputDataType; + + // DML doesn't support int4 FILL_VALUE_CONSTANT yet, so simply create an int8 scalar and reinterpret it to an int4 tensor + auto zeroPointInt8DataType = zeroPointDataType == DML_TENSOR_DATA_TYPE_INT4 ? DML_TENSOR_DATA_TYPE_INT8 : DML_TENSOR_DATA_TYPE_UINT8; + + TensorDesc scalarTensorDesc(zeroPointInt8DataType, std::vector(m_inputTensorDescs[1].GetDimensionCount(), 1)); + DML_TENSOR_DESC scalarDmlTensorDesc = scalarTensorDesc.GetDmlDesc(); + + // Create a tensor full of zeros + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroPointConstantDesc = {}; + zeroPointConstantDesc.ValueDataType = zeroPointInt8DataType; + zeroPointConstantDesc.OutputTensor = &scalarDmlTensorDesc; + DML_OPERATOR_DESC zeroPointConstantDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroPointConstantDesc }; + + // Broadcast the zero point tensor to match the scale tensor + TensorDesc broadcastedScalarTensorDesc(zeroPointDataType, m_inputTensorDescs[1].GetSizes(), std::vector(m_inputTensorDescs[1].GetDimensionCount(), 0)); + quantizationTensors.push_back(broadcastedScalarTensorDesc.GetDmlDesc()); + + // Create the quantize/dequantize operator + TOperatorDesc quantizationOpDesc = {}; + quantizationOpDesc.InputTensor = &inputDescs[0]; + quantizationOpDesc.QuantizationType = DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT; + quantizationOpDesc.QuantizationTensorCount = static_cast(quantizationTensors.size()); + quantizationOpDesc.QuantizationTensors = quantizationTensors.data(); + quantizationOpDesc.OutputTensor = &outputDescs[0]; + DML_OPERATOR_DESC quantizationOpDmlDesc = { ApiTraits::OperatorDescTraits::Type, &quantizationOpDesc }; + + std::array opDescs = { + &zeroPointConstantDmlDesc, + &quantizationOpDmlDesc, + }; + + std::vector inputEdges; + inputEdges.reserve(2); + + std::vector intermediateEdges; + intermediateEdges.reserve(1); + + std::vector outputEdges; + outputEdges.reserve(1); + + // Create an edge between the input tensor and the quantization operator + DML_INPUT_GRAPH_EDGE_DESC inputToQuantizeEdge{}; + inputToQuantizeEdge.GraphInputIndex = 0; + inputToQuantizeEdge.ToNodeIndex = 1; + inputToQuantizeEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToQuantizeEdge); + + // Create an edge between the scale and the quantization operator + DML_INPUT_GRAPH_EDGE_DESC scaleToQuantizeEdge{}; + scaleToQuantizeEdge.GraphInputIndex = 1; + scaleToQuantizeEdge.ToNodeIndex = 1; + scaleToQuantizeEdge.ToNodeInputIndex = 1; + inputEdges.push_back(scaleToQuantizeEdge); + + // Create an edge between the generated zero point tensor and the quantization operator + DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroPointToQuantizeEdge{}; + zeroPointToQuantizeEdge.FromNodeIndex = 0; + zeroPointToQuantizeEdge.FromNodeOutputIndex = 0; + zeroPointToQuantizeEdge.ToNodeIndex = 1; + zeroPointToQuantizeEdge.ToNodeInputIndex = 2; + intermediateEdges.push_back(zeroPointToQuantizeEdge); + + // Create an edge between the output of the quantization operator and the output of the graph + DML_OUTPUT_GRAPH_EDGE_DESC quantizeToOutputEdge{}; + quantizeToOutputEdge.FromNodeIndex = 1; + quantizeToOutputEdge.FromNodeOutputIndex = 0; + quantizeToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(quantizeToOutputEdge); + + // Create the graph + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + } +}; + class DmlOperatorElementwiseIf : public DmlOperator { public: @@ -777,18 +913,20 @@ DML_OP_DEFINE_CREATION_FUNCTION(Max, DmlOperatorElementwiseBinaryLo DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean); // Operators with extra attributes: -DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7); -DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11); -DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12); -DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13); -DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow); -DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear); -DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear); -DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf); -DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod); -DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift); -DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf); -DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound); +DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7); +DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11); +DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12); +DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13); +DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow); +DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear); +DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear); +DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear21, DmlOperatorQuantization21); +DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear21, DmlOperatorQuantization21); +DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf); +DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod); +DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift); +DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf); +DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound); // Fused operators: DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedAdd, DmlOperatorElementwiseBinary); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index cf8f0a4b2d..db8922439e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -225,39 +225,43 @@ ONNX_OPERATOR_KERNEL_EX( namespace Dml { -enum class SupportedTensorDataTypes : uint32_t +enum class SupportedTensorDataTypes : uint64_t { - Undefined = 1<<0, - Float32 = 1<<1, - UInt8 = 1<<2, - Int8 = 1<<3, - UInt16 = 1<<4, - Int16 = 1<<5, - Int32 = 1<<6, - Int64 = 1<<7, - String = 1<<8, - Bool = 1<<9, - Float16 = 1<<10, - Float64 = 1<<11, - UInt32 = 1<<12, - UInt64 = 1<<13, - Complex64 = 1<<14, - Complex128 = 1<<15, - SequenceFloat32 = 1<<16, - SequenceUInt8 = 1<<17, - SequenceInt8 = 1<<18, - SequenceUInt16 = 1<<19, - SequenceInt16 = 1<<20, - SequenceInt32 = 1<<21, - SequenceInt64 = 1<<22, - SequenceString = 1<<23, - SequenceBool = 1<<24, - SequenceFloat16 = 1<<25, - SequenceFloat64 = 1<<26, - SequenceUInt32 = 1<<27, - SequenceUInt64 = 1<<28, - SequenceComplex64 = 1<<29, - SequenceComplex128 = 1<<30, + Undefined = 1LLU<<0, + Float32 = 1LLU<<1, + UInt4 = 1LLU<<2, + Int4 = 1LLU<<3, + UInt8 = 1LLU<<4, + Int8 = 1LLU<<5, + UInt16 = 1LLU<<6, + Int16 = 1LLU<<7, + Int32 = 1LLU<<8, + Int64 = 1LLU<<9, + String = 1LLU<<10, + Bool = 1LLU<<11, + Float16 = 1LLU<<12, + Float64 = 1LLU<<13, + UInt32 = 1LLU<<14, + UInt64 = 1LLU<<15, + Complex64 = 1LLU<<16, + Complex128 = 1LLU<<17, + SequenceFloat32 = 1LLU<<18, + SequenceUInt4 = 1LLU<<19, + SequenceInt4 = 1LLU<<20, + SequenceUInt8 = 1LLU<<21, + SequenceInt8 = 1LLU<<22, + SequenceUInt16 = 1LLU<<23, + SequenceInt16 = 1LLU<<24, + SequenceInt32 = 1LLU<<25, + SequenceInt64 = 1LLU<<26, + SequenceString = 1LLU<<27, + SequenceBool = 1LLU<<28, + SequenceFloat16 = 1LLU<<29, + SequenceFloat64 = 1LLU<<30, + SequenceUInt32 = 1LLU<<31, + SequenceUInt64 = 1LLU<<32, + SequenceComplex64 = 1LLU<<33, + SequenceComplex128 = 1LLU<<34, Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, Ints32to64 = UInt32|Int32|UInt64|Int64, Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64, @@ -273,7 +277,7 @@ enum class SupportedTensorDataTypes : uint32_t Ints16Bit = UInt16|Int16, Ints32Bit = UInt32|Int32, Ints64Bit = UInt64|Int64, - All = static_cast(-1), + All = static_cast(-1), }; DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes); @@ -463,7 +467,9 @@ DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedMatMul); DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedAdd); DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedSum); DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear); +DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear21); DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear); +DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear21); DML_OP_EXTERN_CREATION_FUNCTION(QLinearSigmoid); DML_OP_EXTERN_CREATION_FUNCTION(Sign); DML_OP_EXTERN_CREATION_FUNCTION(IsNaN); @@ -599,8 +605,10 @@ constexpr static std::array supportedTypeListScatte constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListQuantizeLinear19 = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; +constexpr static std::array supportedTypeListQuantizeLinear21 = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListDequantizeLinear19 = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::Float16to32 }; +constexpr static std::array supportedTypeListDequantizeLinear21 = { SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8, SupportedTensorDataTypes::Float16to32 }; constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars }; @@ -848,13 +856,16 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(21, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(21, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(21, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -915,9 +926,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 21, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear21, DmlGraphSupport::Supported)}, {REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 21, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear21, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, @@ -1071,6 +1084,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, @@ -1084,6 +1098,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 21, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, @@ -1246,6 +1261,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) // Scalars if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::UInt4 )) edgeDescs.push_back(TensorEdgeDesc<::MLUInt4x2>()); + if (bool(supportedTypes & SupportedTensorDataTypes::Int4 )) edgeDescs.push_back(TensorEdgeDesc<::MLInt4x2>()); if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(TensorEdgeDesc()); if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(TensorEdgeDesc()); if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(TensorEdgeDesc()); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 686cdbe774..ac77616cb9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -5,6 +5,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" +#include "core/framework/int4.h" #include #include @@ -20,6 +21,8 @@ namespace onnxruntime } using MLFloat16 = onnxruntime::MLFloat16; +using MLUInt4x2 = onnxruntime::Int4x2Base; +using MLInt4x2 = onnxruntime::Int4x2Base; // // Traits for numeric attribute types @@ -43,6 +46,18 @@ struct MLTypeTraits static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int32; }; +template <> +struct MLTypeTraits> +{ + static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::UInt4; +}; + +template <> +struct MLTypeTraits> +{ + static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int4; +}; + template <> struct MLTypeTraits { diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index b775de0b39..aa61ee1dab 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1696,9 +1696,11 @@ using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; using ShapeInferenceHelper_Squeeze7 = VersionedOpsetHelper; using ShapeInferenceHelper_Squeeze11 = VersionedOpsetHelper; using ShapeInferenceHelper_Squeeze13 = VersionedOpsetHelper; +using ShapeInferenceHelper_Squeeze21 = VersionedOpsetHelper; using ShapeInferenceHelper_Unsqueeze7 = VersionedOpsetHelper; using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper; using ShapeInferenceHelper_Unsqueeze13 = VersionedOpsetHelper; +using ShapeInferenceHelper_Unsqueeze21 = VersionedOpsetHelper; using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Col2Im = Col2ImHelper; @@ -1708,6 +1710,7 @@ using ShapeInferenceHelper_Reshape7 = ReshapeHelper; using ShapeInferenceHelper_Reshape13 = ReshapeHelper; using ShapeInferenceHelper_Reshape14 = ReshapeHelper; using ShapeInferenceHelper_Reshape19 = ReshapeHelper; +using ShapeInferenceHelper_Reshape21 = ReshapeHelper; using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper; using ShapeInferenceHelper_Tile = TileHelper; using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; @@ -1754,7 +1757,9 @@ using ShapeInferenceHelper_Asin = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Atan = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_QuantizeLinear21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_DequantizeLinear21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QAttention = QAttentionHelper; using ShapeInferenceHelper_Attention = AttentionHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index f45c2b08db..26529c0d59 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -437,6 +437,17 @@ namespace OperatorHelper static const int sc_sinceVer_ReduceMin = 20; } + namespace OnnxOperatorSet21 + { + static const int sc_sinceVer_QuantizeLinear = 21; + static const int sc_sinceVer_DequantizeLinear = 21; + static const int sc_sinceVer_Squeeze = 21; + static const int sc_sinceVer_Unsqueeze = 21; + static const int sc_sinceVer_Reshape = 21; + static const int sc_sinceVer_Cast = 21; + static const int sc_sinceVer_Shape = 21; + } + namespace MsftOperatorSet1 { static const int sc_sinceVer_DmlFusedConv = 1; diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 386bd7d5f7..cc34f7e18c 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -47,6 +47,19 @@ TEST(DequantizeLinearOpTest, Int4) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar scale with int4 +TEST(DequantizeLinearOpTest, Int4NoZeroPoint) { + OpTester test("DequantizeLinear", 21); + std::vector dims{5}; + constexpr int unused_val = 0; + + // Odd number of int4 values to test packing/unpacking + test.AddInput("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddOutput("y", dims, {-16.0f, -6.0f, 2.0f, 14.0f, 4.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // scalar zero & scale with uint4 TEST(DequantizeLinearOpTest, UInt4) { OpTester test("DequantizeLinear", 21); @@ -61,6 +74,19 @@ TEST(DequantizeLinearOpTest, UInt4) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar scale with uint4 +TEST(DequantizeLinearOpTest, UInt4NoZeroPoint) { + OpTester test("DequantizeLinear", 21); + std::vector dims{5}; + constexpr int unused_val = 0; + + // Odd number of uint4 values to test packing/unpacking + test.AddInput("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddOutput("y", dims, {0.0f, 2.0f, 6.0f, 30.0f, 4.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // Test int16 DequantizeLinear (per tensor) TEST(DequantizeLinearOpTest, Int16) { OpTester test("DequantizeLinear", 21);