diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 96173b5a4e..46d9e217bf 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -958,7 +958,8 @@ Do not modify directly.*
|BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -992,7 +993,8 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|||19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1107,8 +1109,8 @@ Do not modify directly.*
|MeanVarianceNormalization|*in* X:**T**
*out* Y:**T**
or
*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||9+|**T** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)|
-|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8+|**T** = tensor(float), tensor(float16)|
@@ -1145,7 +1147,8 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16)|
|QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**
or
*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)|
+|||19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)|
@@ -1199,7 +1202,8 @@ Do not modify directly.*
|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
-|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**
or
*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**
or
*in* data:**T**
*out* reshaped:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1230,10 +1234,11 @@ Do not modify directly.*
|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
-|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Shape|*in* data:**T**
*out* shape:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
@@ -1242,9 +1247,9 @@ Do not modify directly.*
|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)
**U** = tensor(float), tensor(float16)
**V** = tensor(float), tensor(float16)|
|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
-|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
-|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
@@ -1262,7 +1267,8 @@ Do not modify directly.*
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
-|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**
or
*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**
or
*in* data:**T**
*out* squeezed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1284,7 +1290,8 @@ Do not modify directly.*
|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**
or
*in* data:**T**
*out* expanded:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**
or
*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**
or
*in* X:**T**
*out* Y:**T**|10+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp
index 541254ffaf..aeb0bd9186 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.cpp
@@ -121,7 +121,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
uint32_t deviceTypeMask = 0u;
// Form the bitmask of all supported data types.
- for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
+ for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT4; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
index 0a2a5bbcbe..26559b54bc 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
@@ -247,6 +247,8 @@ namespace Windows::AI::MachineLearning::Adapter
}
ML_TENSOR_TYPE_CASE(float);
+ ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
+ ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
ML_TENSOR_TYPE_CASE(uint8_t);
ML_TENSOR_TYPE_CASE(int8_t);
ML_TENSOR_TYPE_CASE(uint16_t);
@@ -293,6 +295,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetTensorType();
ML_TENSOR_TYPE_CASE(float);
+ ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
+ ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
ML_TENSOR_TYPE_CASE(uint8_t);
ML_TENSOR_TYPE_CASE(int8_t);
ML_TENSOR_TYPE_CASE(uint16_t);
@@ -314,6 +318,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetSequenceTensorType();
ML_SEQUENCE_TENSOR_TYPE_CASE(float);
+ ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
+ ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base);
ML_SEQUENCE_TENSOR_TYPE_CASE(uint8_t);
ML_SEQUENCE_TENSOR_TYPE_CASE(int8_t);
ML_SEQUENCE_TENSOR_TYPE_CASE(uint16_t);
@@ -335,6 +341,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetType();
ML_PRIMITIVE_TYPE_CASE(float);
+ ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base);
+ ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base);
ML_PRIMITIVE_TYPE_CASE(uint8_t);
ML_PRIMITIVE_TYPE_CASE(int8_t);
ML_PRIMITIVE_TYPE_CASE(uint16_t);
@@ -364,6 +372,12 @@ namespace Windows::AI::MachineLearning::Adapter
case onnx::TensorProto_DataType_FLOAT:
return MLOperatorTensorDataType::Float;
+ case onnx::TensorProto_DataType_UINT4:
+ return MLOperatorTensorDataType::UInt4;
+
+ case onnx::TensorProto_DataType_INT4:
+ return MLOperatorTensorDataType::Int4;
+
case onnx::TensorProto_DataType_UINT8:
return MLOperatorTensorDataType::UInt8;
@@ -455,6 +469,12 @@ namespace Windows::AI::MachineLearning::Adapter
case MLOperatorTensorDataType::Float:
return "tensor(float)";
+ case MLOperatorTensorDataType::UInt4:
+ return "tensor(uint4)";
+
+ case MLOperatorTensorDataType::Int4:
+ return "tensor(int4)";
+
case MLOperatorTensorDataType::UInt8:
return "tensor(uint8)";
@@ -509,6 +529,12 @@ namespace Windows::AI::MachineLearning::Adapter
case MLOperatorTensorDataType::Float:
return "seq(tensor(float))";
+ case MLOperatorTensorDataType::UInt4:
+ return "seq(tensor(uint4))";
+
+ case MLOperatorTensorDataType::Int4:
+ return "seq(tensor(int4))";
+
case MLOperatorTensorDataType::UInt8:
return "seq(tensor(uint8))";
@@ -1518,7 +1544,7 @@ namespace Windows::AI::MachineLearning::Adapter
AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node);
m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc)));
}
-
+
// There can be operators (or kernels) which don't require any input.
assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr);
m_graphNodeCreateInfo->inputEdges.insert(
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
index 16bb10f004..412207fd3c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -528,6 +528,7 @@ public:
std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size());
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
+ const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
uint32_t axis = 0;
@@ -601,6 +602,141 @@ public:
}
};
+template
+class DmlOperatorQuantization21 : public DmlOperator // Kernel class behind the QuantizeLinear21 / DequantizeLinear21 creation functions.
+{
+public:
+ DmlOperatorQuantization21(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
+ {
+ ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2 || kernelInfo.GetInputCount() == 3); // data, scale, and an optional zero point
+ ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
+
+ Initialize(kernelInfo, std::nullopt, std::nullopt);
+
+ std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
+ const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); // NOTE(review): outputShape/outputShapeDimCount are not read again in this constructor - confirm they are needed
+ const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
+ const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
+ bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector outputDescs = GetDmlOutputDescs();
+
+ std::vector quantizationTensors;
+ quantizationTensors.push_back(inputDescs[1]);
+
+ const bool isSignedQuantization =
+ inputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
+ outputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
+ inputDataType == DML_TENSOR_DATA_TYPE_INT8 ||
+ outputDataType == DML_TENSOR_DATA_TYPE_INT8;
+
+ if (hasZeroPointTensor || isSignedQuantization) // Single-operator path: the zero point is supplied, or the data is signed and no override is generated.
+ {
+ if (hasZeroPointTensor)
+ {
+ quantizationTensors.push_back(inputDescs[2]);
+ }
+
+ TOperatorDesc opDesc = {};
+ opDesc.InputTensor = &inputDescs[0];
+ opDesc.QuantizationType = hasZeroPointTensor ? DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT : DML_QUANTIZATION_TYPE_SCALE;
+ opDesc.QuantizationTensorCount = static_cast(quantizationTensors.size());
+ opDesc.QuantizationTensors = quantizationTensors.data();
+ opDesc.OutputTensor = &outputDescs[0];
+ SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo);
+ }
+ else
+ {
+ // Unsigned quantization without a zero point input: DML defaults the zero point to the midpoint of the data type (e.g. 8 for uint4
+ // and 128 for uint8), which is also the default used by various other frameworks and quantization tools. ONNX, however, uses a
+ // default zero point of 0 for signed and unsigned types alike, so a constant zero point tensor with a value of 0 is generated
+ // here and wired into the operator to override DML's default.
+ auto zeroPointDataType = std::is_same_v ? outputDataType : inputDataType;
+
+ // DML doesn't support 4-bit FILL_VALUE_CONSTANT yet, so an 8-bit scalar is created and reinterpreted as the 4-bit type below. NOTE(review): INT4 is excluded from this branch by isSignedQuantization, so the INT8 arm looks unreachable - confirm.
+ auto zeroPointInt8DataType = zeroPointDataType == DML_TENSOR_DATA_TYPE_INT4 ? DML_TENSOR_DATA_TYPE_INT8 : DML_TENSOR_DATA_TYPE_UINT8;
+
+ TensorDesc scalarTensorDesc(zeroPointInt8DataType, std::vector(m_inputTensorDescs[1].GetDimensionCount(), 1));
+ DML_TENSOR_DESC scalarDmlTensorDesc = scalarTensorDesc.GetDmlDesc();
+
+ // Node 0: constant-fill operator producing the zero point; the desc is value-initialized and no fill value is assigned (presumably leaving it at 0 - confirm)
+ DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroPointConstantDesc = {};
+ zeroPointConstantDesc.ValueDataType = zeroPointInt8DataType;
+ zeroPointConstantDesc.OutputTensor = &scalarDmlTensorDesc;
+ DML_OPERATOR_DESC zeroPointConstantDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroPointConstantDesc };
+
+ // Describe the zero point with the scale tensor's sizes and an all-zero vector (presumably strides) so the scalar is broadcast to match the scale tensor
+ TensorDesc broadcastedScalarTensorDesc(zeroPointDataType, m_inputTensorDescs[1].GetSizes(), std::vector(m_inputTensorDescs[1].GetDimensionCount(), 0));
+ quantizationTensors.push_back(broadcastedScalarTensorDesc.GetDmlDesc());
+
+ // Node 1: the quantize/dequantize operator, always given an explicit zero point on this path
+ TOperatorDesc quantizationOpDesc = {};
+ quantizationOpDesc.InputTensor = &inputDescs[0];
+ quantizationOpDesc.QuantizationType = DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT;
+ quantizationOpDesc.QuantizationTensorCount = static_cast(quantizationTensors.size());
+ quantizationOpDesc.QuantizationTensors = quantizationTensors.data();
+ quantizationOpDesc.OutputTensor = &outputDescs[0];
+ DML_OPERATOR_DESC quantizationOpDmlDesc = { ApiTraits::OperatorDescTraits::Type, &quantizationOpDesc };
+
+ std::array opDescs = {
+ &zeroPointConstantDmlDesc,
+ &quantizationOpDmlDesc,
+ };
+
+ std::vector inputEdges;
+ inputEdges.reserve(2);
+
+ std::vector intermediateEdges;
+ intermediateEdges.reserve(1);
+
+ std::vector outputEdges;
+ outputEdges.reserve(1);
+
+ // Graph input 0 (the data tensor) -> input 0 of the quantization operator (node 1)
+ DML_INPUT_GRAPH_EDGE_DESC inputToQuantizeEdge{};
+ inputToQuantizeEdge.GraphInputIndex = 0;
+ inputToQuantizeEdge.ToNodeIndex = 1;
+ inputToQuantizeEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToQuantizeEdge);
+
+ // Graph input 1 (the scale) -> input 1 of the quantization operator
+ DML_INPUT_GRAPH_EDGE_DESC scaleToQuantizeEdge{};
+ scaleToQuantizeEdge.GraphInputIndex = 1;
+ scaleToQuantizeEdge.ToNodeIndex = 1;
+ scaleToQuantizeEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(scaleToQuantizeEdge);
+
+ // Output 0 of node 0 (the generated zero point) -> input 2 of the quantization operator
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroPointToQuantizeEdge{};
+ zeroPointToQuantizeEdge.FromNodeIndex = 0;
+ zeroPointToQuantizeEdge.FromNodeOutputIndex = 0;
+ zeroPointToQuantizeEdge.ToNodeIndex = 1;
+ zeroPointToQuantizeEdge.ToNodeInputIndex = 2;
+ intermediateEdges.push_back(zeroPointToQuantizeEdge);
+
+ // Output 0 of the quantization operator -> graph output 0
+ DML_OUTPUT_GRAPH_EDGE_DESC quantizeToOutputEdge{};
+ quantizeToOutputEdge.FromNodeIndex = 1;
+ quantizeToOutputEdge.FromNodeOutputIndex = 0;
+ quantizeToOutputEdge.GraphOutputIndex = 0;
+ outputEdges.push_back(quantizeToOutputEdge);
+
+ // Assemble the two-node graph description and register it through SetDmlOperatorGraphDesc
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size());
+ operatorGraphDesc.nodes = opDescs.data();
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
+ }
+ }
+};
+
class DmlOperatorElementwiseIf : public DmlOperator
{
public:
@@ -777,18 +913,20 @@ DML_OP_DEFINE_CREATION_FUNCTION(Max, DmlOperatorElementwiseBinaryLo
DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
// Operators with extra attributes:
-DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
-DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
-DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
-DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
-DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
-DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear);
-DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear);
-DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
-DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
-DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
-DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
-DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
+DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
+DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
+DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
+DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
+DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
+DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear);
+DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear);
+DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear21, DmlOperatorQuantization21);
+DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear21, DmlOperatorQuantization21);
+DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
+DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
+DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
+DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
+DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
// Fused operators:
DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedAdd, DmlOperatorElementwiseBinary);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index cf8f0a4b2d..db8922439e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -225,39 +225,43 @@ ONNX_OPERATOR_KERNEL_EX(
namespace Dml
{
-enum class SupportedTensorDataTypes : uint32_t
+enum class SupportedTensorDataTypes : uint64_t
{
- Undefined = 1<<0,
- Float32 = 1<<1,
- UInt8 = 1<<2,
- Int8 = 1<<3,
- UInt16 = 1<<4,
- Int16 = 1<<5,
- Int32 = 1<<6,
- Int64 = 1<<7,
- String = 1<<8,
- Bool = 1<<9,
- Float16 = 1<<10,
- Float64 = 1<<11,
- UInt32 = 1<<12,
- UInt64 = 1<<13,
- Complex64 = 1<<14,
- Complex128 = 1<<15,
- SequenceFloat32 = 1<<16,
- SequenceUInt8 = 1<<17,
- SequenceInt8 = 1<<18,
- SequenceUInt16 = 1<<19,
- SequenceInt16 = 1<<20,
- SequenceInt32 = 1<<21,
- SequenceInt64 = 1<<22,
- SequenceString = 1<<23,
- SequenceBool = 1<<24,
- SequenceFloat16 = 1<<25,
- SequenceFloat64 = 1<<26,
- SequenceUInt32 = 1<<27,
- SequenceUInt64 = 1<<28,
- SequenceComplex64 = 1<<29,
- SequenceComplex128 = 1<<30,
+ Undefined = 1LLU<<0,
+ Float32 = 1LLU<<1,
+ UInt4 = 1LLU<<2,
+ Int4 = 1LLU<<3,
+ UInt8 = 1LLU<<4,
+ Int8 = 1LLU<<5,
+ UInt16 = 1LLU<<6,
+ Int16 = 1LLU<<7,
+ Int32 = 1LLU<<8,
+ Int64 = 1LLU<<9,
+ String = 1LLU<<10,
+ Bool = 1LLU<<11,
+ Float16 = 1LLU<<12,
+ Float64 = 1LLU<<13,
+ UInt32 = 1LLU<<14,
+ UInt64 = 1LLU<<15,
+ Complex64 = 1LLU<<16,
+ Complex128 = 1LLU<<17,
+ SequenceFloat32 = 1LLU<<18,
+ SequenceUInt4 = 1LLU<<19,
+ SequenceInt4 = 1LLU<<20,
+ SequenceUInt8 = 1LLU<<21,
+ SequenceInt8 = 1LLU<<22,
+ SequenceUInt16 = 1LLU<<23,
+ SequenceInt16 = 1LLU<<24,
+ SequenceInt32 = 1LLU<<25,
+ SequenceInt64 = 1LLU<<26,
+ SequenceString = 1LLU<<27,
+ SequenceBool = 1LLU<<28,
+ SequenceFloat16 = 1LLU<<29,
+ SequenceFloat64 = 1LLU<<30,
+ SequenceUInt32 = 1LLU<<31,
+ SequenceUInt64 = 1LLU<<32,
+ SequenceComplex64 = 1LLU<<33,
+ SequenceComplex128 = 1LLU<<34,
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Ints32to64 = UInt32|Int32|UInt64|Int64,
Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64,
@@ -273,7 +277,7 @@ enum class SupportedTensorDataTypes : uint32_t
Ints16Bit = UInt16|Int16,
Ints32Bit = UInt32|Int32,
Ints64Bit = UInt64|Int64,
- All = static_cast(-1),
+ All = static_cast(-1),
};
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
@@ -463,7 +467,9 @@ DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedAdd);
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedSum);
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear);
+DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear21);
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear);
+DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear21);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearSigmoid);
DML_OP_EXTERN_CREATION_FUNCTION(Sign);
DML_OP_EXTERN_CREATION_FUNCTION(IsNaN);
@@ -599,8 +605,10 @@ constexpr static std::array supportedTypeListScatte
constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array supportedTypeListQuantizeLinear19 = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
+constexpr static std::array supportedTypeListQuantizeLinear21 = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
constexpr static std::array supportedTypeListDequantizeLinear19 = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::Float16to32 };
+constexpr static std::array supportedTypeListDequantizeLinear21 = { SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
@@ -848,13 +856,16 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
+ {REG_INFO_COPY(21, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
+ {REG_INFO_COPY(21, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
+ {REG_INFO_COPY(21, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@@ -915,9 +926,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
+ {REG_INFO_VER( 21, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear21, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
+ {REG_INFO_VER( 21, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear21, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
@@ -1071,6 +1084,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
+ {REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
@@ -1084,6 +1098,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
+ {REG_INFO( 21, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
@@ -1246,6 +1261,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// Scalars
if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(TensorEdgeDesc());
+ if (bool(supportedTypes & SupportedTensorDataTypes::UInt4 )) edgeDescs.push_back(TensorEdgeDesc<::MLUInt4x2>());
+ if (bool(supportedTypes & SupportedTensorDataTypes::Int4 )) edgeDescs.push_back(TensorEdgeDesc<::MLInt4x2>());
if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(TensorEdgeDesc());
if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(TensorEdgeDesc());
if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(TensorEdgeDesc());
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
index 686cdbe774..ac77616cb9 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
@@ -5,6 +5,7 @@
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
#include "MLOperatorAuthorPrivate.h"
+#include "core/framework/int4.h"
#include
#include
@@ -20,6 +21,8 @@ namespace onnxruntime
}
using MLFloat16 = onnxruntime::MLFloat16;
+using MLUInt4x2 = onnxruntime::Int4x2Base;
+using MLInt4x2 = onnxruntime::Int4x2Base;
//
// Traits for numeric attribute types
@@ -43,6 +46,18 @@ struct MLTypeTraits
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int32;
};
+template <> // Trait specialization reporting MLOperatorTensorDataType::UInt4 as the tensor element type.
+struct MLTypeTraits>
+{
+ static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::UInt4;
+};
+
+template <> // Trait specialization reporting MLOperatorTensorDataType::Int4 as the tensor element type.
+struct MLTypeTraits>
+{
+ static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int4;
+};
+
template <>
struct MLTypeTraits
{
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index b775de0b39..aa61ee1dab 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1696,9 +1696,11 @@ using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
using ShapeInferenceHelper_Squeeze7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Squeeze11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Squeeze13 = VersionedOpsetHelper;
+using ShapeInferenceHelper_Squeeze21 = VersionedOpsetHelper;
using ShapeInferenceHelper_Unsqueeze7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Unsqueeze13 = VersionedOpsetHelper;
+using ShapeInferenceHelper_Unsqueeze21 = VersionedOpsetHelper;
using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Col2Im = Col2ImHelper;
@@ -1708,6 +1710,7 @@ using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
using ShapeInferenceHelper_Reshape13 = ReshapeHelper;
using ShapeInferenceHelper_Reshape14 = ReshapeHelper;
using ShapeInferenceHelper_Reshape19 = ReshapeHelper;
+using ShapeInferenceHelper_Reshape21 = ReshapeHelper;
using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper;
using ShapeInferenceHelper_Tile = TileHelper;
using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper;
@@ -1754,7 +1757,9 @@ using ShapeInferenceHelper_Asin = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Atan = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_QuantizeLinear21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_DequantizeLinear21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index f45c2b08db..26529c0d59 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -437,6 +437,17 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceMin = 20;
}
+ namespace OnnxOperatorSet21 // Since-version constants (all 21) for the operators listed below.
+ {
+ static const int sc_sinceVer_QuantizeLinear = 21;
+ static const int sc_sinceVer_DequantizeLinear = 21;
+ static const int sc_sinceVer_Squeeze = 21;
+ static const int sc_sinceVer_Unsqueeze = 21;
+ static const int sc_sinceVer_Reshape = 21;
+ static const int sc_sinceVer_Cast = 21;
+ static const int sc_sinceVer_Shape = 21;
+ }
+
namespace MsftOperatorSet1
{
static const int sc_sinceVer_DmlFusedConv = 1;
diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
index 386bd7d5f7..cc34f7e18c 100644
--- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
@@ -47,6 +47,19 @@ TEST(DequantizeLinearOpTest, Int4) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+// Scalar scale with int4 data and no zero point input: every expected output is x * x_scale.
+TEST(DequantizeLinearOpTest, Int4NoZeroPoint) {
+ OpTester test("DequantizeLinear", 21);
+ std::vector dims{5};
+ constexpr int unused_val = 0;
+
+ // Odd number of int4 values to test packing/unpacking: five values in three Int4x2 pairs, so the last pair carries one unused slot
+ test.AddInput("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)});
+ test.AddInput("x_scale", {}, {2.0f});
+ test.AddOutput("y", dims, {-16.0f, -6.0f, 2.0f, 14.0f, 4.0f});
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
// scalar zero & scale with uint4
TEST(DequantizeLinearOpTest, UInt4) {
OpTester test("DequantizeLinear", 21);
@@ -61,6 +74,19 @@ TEST(DequantizeLinearOpTest, UInt4) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+// Scalar scale with uint4 data and no zero point input: every expected output is x * x_scale.
+TEST(DequantizeLinearOpTest, UInt4NoZeroPoint) {
+ OpTester test("DequantizeLinear", 21);
+ std::vector dims{5};
+ constexpr int unused_val = 0;
+
+ // Odd number of uint4 values to test packing/unpacking: five values in three UInt4x2 pairs, so the last pair carries one unused slot
+ test.AddInput("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)});
+ test.AddInput("x_scale", {}, {2.0f});
+ test.AddOutput("y", dims, {0.0f, 2.0f, 6.0f, 30.0f, 4.0f});
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
// Test int16 DequantizeLinear (per tensor)
TEST(DequantizeLinearOpTest, Int16) {
OpTester test("DequantizeLinear", 21);