[DML] Add int4 QDQ (#21592)

This commit is contained in:
Patrice Vignola 2024-08-20 23:44:58 -07:00 committed by GitHub
parent 12f426c63f
commit de6ebcbb54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 307 additions and 62 deletions

View file

@ -958,7 +958,8 @@ Do not modify directly.*
|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Cast|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -992,7 +993,8 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|Div|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -1107,8 +1109,8 @@ Do not modify directly.*
|MeanVarianceNormalization|*in* X:**T**<br> *out* Y:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||9+|**T** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)|
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**<br> *out* min:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8+|**T** = tensor(float), tensor(float16)|
@ -1145,7 +1147,8 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)|
|||19+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|||13+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|||10+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|RNN|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)|
@ -1199,7 +1202,8 @@ Do not modify directly.*
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -1230,10 +1234,11 @@ Do not modify directly.*
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|Shape|*in* data:**T**<br> *out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shape|*in* data:**T**<br> *out* shape:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
@ -1242,9 +1247,9 @@ Do not modify directly.*
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Slice|*in* data:**T**<br> *in* starts:**Tind**<br> *in* ends:**Tind**<br> *in* axes:**Tind**<br> *in* steps:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
@ -1262,7 +1267,8 @@ Do not modify directly.*
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sub|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -1284,7 +1290,8 @@ Do not modify directly.*
|Transpose|*in* data:**T**<br> *out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Upsample|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(float), tensor(float16)|

View file

@ -121,7 +121,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
uint32_t deviceTypeMask = 0u;
// Form the bitmask of all supported data types.
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT4; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};

View file

@ -247,6 +247,8 @@ namespace Windows::AI::MachineLearning::Adapter
}
ML_TENSOR_TYPE_CASE(float);
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
ML_TENSOR_TYPE_CASE(uint8_t);
ML_TENSOR_TYPE_CASE(int8_t);
ML_TENSOR_TYPE_CASE(uint16_t);
@ -293,6 +295,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetTensorType<std::string>();
ML_TENSOR_TYPE_CASE(float);
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
ML_TENSOR_TYPE_CASE(uint8_t);
ML_TENSOR_TYPE_CASE(int8_t);
ML_TENSOR_TYPE_CASE(uint16_t);
@ -314,6 +318,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetSequenceTensorType<std::string>();
ML_SEQUENCE_TENSOR_TYPE_CASE(float);
ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
ML_SEQUENCE_TENSOR_TYPE_CASE(uint8_t);
ML_SEQUENCE_TENSOR_TYPE_CASE(int8_t);
ML_SEQUENCE_TENSOR_TYPE_CASE(uint16_t);
@ -335,6 +341,8 @@ namespace Windows::AI::MachineLearning::Adapter
return onnxruntime::DataTypeImpl::GetType<std::string>();
ML_PRIMITIVE_TYPE_CASE(float);
ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base<false>);
ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base<true>);
ML_PRIMITIVE_TYPE_CASE(uint8_t);
ML_PRIMITIVE_TYPE_CASE(int8_t);
ML_PRIMITIVE_TYPE_CASE(uint16_t);
@ -364,6 +372,12 @@ namespace Windows::AI::MachineLearning::Adapter
case onnx::TensorProto_DataType_FLOAT:
return MLOperatorTensorDataType::Float;
case onnx::TensorProto_DataType_UINT4:
return MLOperatorTensorDataType::UInt4;
case onnx::TensorProto_DataType_INT4:
return MLOperatorTensorDataType::Int4;
case onnx::TensorProto_DataType_UINT8:
return MLOperatorTensorDataType::UInt8;
@ -455,6 +469,12 @@ namespace Windows::AI::MachineLearning::Adapter
case MLOperatorTensorDataType::Float:
return "tensor(float)";
case MLOperatorTensorDataType::UInt4:
return "tensor(uint4)";
case MLOperatorTensorDataType::Int4:
return "tensor(int4)";
case MLOperatorTensorDataType::UInt8:
return "tensor(uint8)";
@ -509,6 +529,12 @@ namespace Windows::AI::MachineLearning::Adapter
case MLOperatorTensorDataType::Float:
return "seq(tensor(float))";
case MLOperatorTensorDataType::UInt4:
return "seq(tensor(uint4))";
case MLOperatorTensorDataType::Int4:
return "seq(tensor(int4))";
case MLOperatorTensorDataType::UInt8:
return "seq(tensor(uint8))";
@ -1518,7 +1544,7 @@ namespace Windows::AI::MachineLearning::Adapter
AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node);
m_graphNodeCreateInfo->nodes.push_back(std::make_unique<AbstractOperatorDesc>(std::move(abstractDesc)));
}
// There can be operators (or kernels) which don't require any input.
assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr);
m_graphNodeCreateInfo->inputEdges.insert(

View file

@ -528,6 +528,7 @@ public:
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
uint32_t axis = 0;
@ -601,6 +602,141 @@ public:
}
};
template <typename TOperatorDesc>
class DmlOperatorQuantization21 : public DmlOperator
{
public:
DmlOperatorQuantization21(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2 || kernelInfo.GetInputCount() == 3);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt);
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::vector<DML_TENSOR_DESC> quantizationTensors;
quantizationTensors.push_back(inputDescs[1]);
const bool isSignedQuantization =
inputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
outputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
inputDataType == DML_TENSOR_DATA_TYPE_INT8 ||
outputDataType == DML_TENSOR_DATA_TYPE_INT8;
if (hasZeroPointTensor || isSignedQuantization)
{
if (hasZeroPointTensor)
{
quantizationTensors.push_back(inputDescs[2]);
}
TOperatorDesc opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.QuantizationType = hasZeroPointTensor ? DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT : DML_QUANTIZATION_TYPE_SCALE;
opDesc.QuantizationTensorCount = static_cast<uint32_t>(quantizationTensors.size());
opDesc.QuantizationTensors = quantizationTensors.data();
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
else
{
// For unsigned quantization, DML uses the midpoint of the datatype when zero point isn't provided (e.g. 8 for uint4 and 128 for uint8)
// since this is the most sane default in theory that represents the midpoint of the quantization. This is also the default used by various
// other frameworks and quantization tools. But because ONNX uses a default zero point of 0 no matter if it's signed or unsigned, we need
// to generate a constant zero point tensor with a value of 0 to override dml's default value when it isn't provided.
auto zeroPointDataType = std::is_same_v<TOperatorDesc, DML_QUANTIZE_OPERATOR_DESC> ? outputDataType : inputDataType;
// DML doesn't support int4 FILL_VALUE_CONSTANT yet, so simply create an int8 scalar and reinterpret it to an int4 tensor
auto zeroPointInt8DataType = zeroPointDataType == DML_TENSOR_DATA_TYPE_INT4 ? DML_TENSOR_DATA_TYPE_INT8 : DML_TENSOR_DATA_TYPE_UINT8;
TensorDesc scalarTensorDesc(zeroPointInt8DataType, std::vector<uint32_t>(m_inputTensorDescs[1].GetDimensionCount(), 1));
DML_TENSOR_DESC scalarDmlTensorDesc = scalarTensorDesc.GetDmlDesc();
// Create a tensor full of zeros
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroPointConstantDesc = {};
zeroPointConstantDesc.ValueDataType = zeroPointInt8DataType;
zeroPointConstantDesc.OutputTensor = &scalarDmlTensorDesc;
DML_OPERATOR_DESC zeroPointConstantDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroPointConstantDesc };
// Broadcast the zero point tensor to match the scale tensor
TensorDesc broadcastedScalarTensorDesc(zeroPointDataType, m_inputTensorDescs[1].GetSizes(), std::vector<uint32_t>(m_inputTensorDescs[1].GetDimensionCount(), 0));
quantizationTensors.push_back(broadcastedScalarTensorDesc.GetDmlDesc());
// Create the quantize/dequantize operator
TOperatorDesc quantizationOpDesc = {};
quantizationOpDesc.InputTensor = &inputDescs[0];
quantizationOpDesc.QuantizationType = DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT;
quantizationOpDesc.QuantizationTensorCount = static_cast<uint32_t>(quantizationTensors.size());
quantizationOpDesc.QuantizationTensors = quantizationTensors.data();
quantizationOpDesc.OutputTensor = &outputDescs[0];
DML_OPERATOR_DESC quantizationOpDmlDesc = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &quantizationOpDesc };
std::array<const DML_OPERATOR_DESC*, 2> opDescs = {
&zeroPointConstantDmlDesc,
&quantizationOpDmlDesc,
};
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
inputEdges.reserve(2);
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
intermediateEdges.reserve(1);
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
outputEdges.reserve(1);
// Create an edge between the input tensor and the quantization operator
DML_INPUT_GRAPH_EDGE_DESC inputToQuantizeEdge{};
inputToQuantizeEdge.GraphInputIndex = 0;
inputToQuantizeEdge.ToNodeIndex = 1;
inputToQuantizeEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToQuantizeEdge);
// Create an edge between the scale and the quantization operator
DML_INPUT_GRAPH_EDGE_DESC scaleToQuantizeEdge{};
scaleToQuantizeEdge.GraphInputIndex = 1;
scaleToQuantizeEdge.ToNodeIndex = 1;
scaleToQuantizeEdge.ToNodeInputIndex = 1;
inputEdges.push_back(scaleToQuantizeEdge);
// Create an edge between the generated zero point tensor and the quantization operator
DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroPointToQuantizeEdge{};
zeroPointToQuantizeEdge.FromNodeIndex = 0;
zeroPointToQuantizeEdge.FromNodeOutputIndex = 0;
zeroPointToQuantizeEdge.ToNodeIndex = 1;
zeroPointToQuantizeEdge.ToNodeInputIndex = 2;
intermediateEdges.push_back(zeroPointToQuantizeEdge);
// Create an edge between the output of the quantization operator and the output of the graph
DML_OUTPUT_GRAPH_EDGE_DESC quantizeToOutputEdge{};
quantizeToOutputEdge.FromNodeIndex = 1;
quantizeToOutputEdge.FromNodeOutputIndex = 0;
quantizeToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(quantizeToOutputEdge);
// Create the graph
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodes = opDescs.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
}
}
};
class DmlOperatorElementwiseIf : public DmlOperator
{
public:
@ -777,18 +913,20 @@ DML_OP_DEFINE_CREATION_FUNCTION(Max, DmlOperatorElementwiseBinaryLo
DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
// Operators with extra attributes:
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear21, DmlOperatorQuantization21<DML_QUANTIZE_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear21, DmlOperatorQuantization21<DML_DEQUANTIZE_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
// Fused operators:
DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedAdd, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);

View file

@ -225,39 +225,43 @@ ONNX_OPERATOR_KERNEL_EX(
namespace Dml
{
enum class SupportedTensorDataTypes : uint32_t
enum class SupportedTensorDataTypes : uint64_t
{
Undefined = 1<<0,
Float32 = 1<<1,
UInt8 = 1<<2,
Int8 = 1<<3,
UInt16 = 1<<4,
Int16 = 1<<5,
Int32 = 1<<6,
Int64 = 1<<7,
String = 1<<8,
Bool = 1<<9,
Float16 = 1<<10,
Float64 = 1<<11,
UInt32 = 1<<12,
UInt64 = 1<<13,
Complex64 = 1<<14,
Complex128 = 1<<15,
SequenceFloat32 = 1<<16,
SequenceUInt8 = 1<<17,
SequenceInt8 = 1<<18,
SequenceUInt16 = 1<<19,
SequenceInt16 = 1<<20,
SequenceInt32 = 1<<21,
SequenceInt64 = 1<<22,
SequenceString = 1<<23,
SequenceBool = 1<<24,
SequenceFloat16 = 1<<25,
SequenceFloat64 = 1<<26,
SequenceUInt32 = 1<<27,
SequenceUInt64 = 1<<28,
SequenceComplex64 = 1<<29,
SequenceComplex128 = 1<<30,
Undefined = 1LLU<<0,
Float32 = 1LLU<<1,
UInt4 = 1LLU<<2,
Int4 = 1LLU<<3,
UInt8 = 1LLU<<4,
Int8 = 1LLU<<5,
UInt16 = 1LLU<<6,
Int16 = 1LLU<<7,
Int32 = 1LLU<<8,
Int64 = 1LLU<<9,
String = 1LLU<<10,
Bool = 1LLU<<11,
Float16 = 1LLU<<12,
Float64 = 1LLU<<13,
UInt32 = 1LLU<<14,
UInt64 = 1LLU<<15,
Complex64 = 1LLU<<16,
Complex128 = 1LLU<<17,
SequenceFloat32 = 1LLU<<18,
SequenceUInt4 = 1LLU<<19,
SequenceInt4 = 1LLU<<20,
SequenceUInt8 = 1LLU<<21,
SequenceInt8 = 1LLU<<22,
SequenceUInt16 = 1LLU<<23,
SequenceInt16 = 1LLU<<24,
SequenceInt32 = 1LLU<<25,
SequenceInt64 = 1LLU<<26,
SequenceString = 1LLU<<27,
SequenceBool = 1LLU<<28,
SequenceFloat16 = 1LLU<<29,
SequenceFloat64 = 1LLU<<30,
SequenceUInt32 = 1LLU<<31,
SequenceUInt64 = 1LLU<<32,
SequenceComplex64 = 1LLU<<33,
SequenceComplex128 = 1LLU<<34,
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Ints32to64 = UInt32|Int32|UInt64|Int64,
Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64,
@ -273,7 +277,7 @@ enum class SupportedTensorDataTypes : uint32_t
Ints16Bit = UInt16|Int16,
Ints32Bit = UInt32|Int32,
Ints64Bit = UInt64|Int64,
All = static_cast<uint32_t>(-1),
All = static_cast<uint64_t>(-1),
};
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
@ -463,7 +467,9 @@ DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedAdd);
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedSum);
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear21);
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear21);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearSigmoid);
DML_OP_EXTERN_CREATION_FUNCTION(Sign);
DML_OP_EXTERN_CREATION_FUNCTION(IsNaN);
@ -599,8 +605,10 @@ constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatte
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear19 = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear21 = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear19 = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear21 = { SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
@ -848,13 +856,16 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(21, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(13, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(21, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(21, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -915,9 +926,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_VER( 21, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear21, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_VER( 21, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear21, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
@ -1071,6 +1084,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
@ -1084,6 +1098,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 21, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
@ -1246,6 +1261,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// Scalars
if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(TensorEdgeDesc<float>());
if (bool(supportedTypes & SupportedTensorDataTypes::UInt4 )) edgeDescs.push_back(TensorEdgeDesc<::MLUInt4x2>());
if (bool(supportedTypes & SupportedTensorDataTypes::Int4 )) edgeDescs.push_back(TensorEdgeDesc<::MLInt4x2>());
if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(TensorEdgeDesc<uint8_t>());
if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(TensorEdgeDesc<int8_t>());
if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(TensorEdgeDesc<uint16_t>());

View file

@ -5,6 +5,7 @@
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
#include "MLOperatorAuthorPrivate.h"
#include "core/framework/int4.h"
#include <gsl/gsl>
#include <optional>
@ -20,6 +21,8 @@ namespace onnxruntime
}
using MLFloat16 = onnxruntime::MLFloat16;
using MLUInt4x2 = onnxruntime::Int4x2Base<false>;
using MLInt4x2 = onnxruntime::Int4x2Base<true>;
//
// Traits for numeric attribute types
@ -43,6 +46,18 @@ struct MLTypeTraits<int32_t>
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int32;
};
template <>
struct MLTypeTraits<onnxruntime::Int4x2Base<false>>
{
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::UInt4;
};
template <>
struct MLTypeTraits<onnxruntime::Int4x2Base<true>>
{
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int4;
};
template <>
struct MLTypeTraits<uint8_t>
{

View file

@ -1696,9 +1696,11 @@ using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
using ShapeInferenceHelper_Squeeze7 = VersionedOpsetHelper<SqueezeHelper, 7>;
using ShapeInferenceHelper_Squeeze11 = VersionedOpsetHelper<SqueezeHelper, 11>;
using ShapeInferenceHelper_Squeeze13 = VersionedOpsetHelper<SqueezeHelper, 13>;
using ShapeInferenceHelper_Squeeze21 = VersionedOpsetHelper<SqueezeHelper, 21>;
using ShapeInferenceHelper_Unsqueeze7 = VersionedOpsetHelper<UnsqueezeHelper, 7>;
using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper<UnsqueezeHelper, 11>;
using ShapeInferenceHelper_Unsqueeze13 = VersionedOpsetHelper<UnsqueezeHelper, 13>;
using ShapeInferenceHelper_Unsqueeze21 = VersionedOpsetHelper<UnsqueezeHelper, 21>;
using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Col2Im = Col2ImHelper;
@ -1708,6 +1710,7 @@ using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
using ShapeInferenceHelper_Reshape13 = ReshapeHelper;
using ShapeInferenceHelper_Reshape14 = ReshapeHelper;
using ShapeInferenceHelper_Reshape19 = ReshapeHelper;
using ShapeInferenceHelper_Reshape21 = ReshapeHelper;
using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper;
using ShapeInferenceHelper_Tile = TileHelper;
using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper<ResizeHelper, 10>;
@ -1754,7 +1757,9 @@ using ShapeInferenceHelper_Asin = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Atan = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;

View file

@ -437,6 +437,17 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceMin = 20;
}
namespace OnnxOperatorSet21
{
static const int sc_sinceVer_QuantizeLinear = 21;
static const int sc_sinceVer_DequantizeLinear = 21;
static const int sc_sinceVer_Squeeze = 21;
static const int sc_sinceVer_Unsqueeze = 21;
static const int sc_sinceVer_Reshape = 21;
static const int sc_sinceVer_Cast = 21;
static const int sc_sinceVer_Shape = 21;
}
namespace MsftOperatorSet1
{
static const int sc_sinceVer_DmlFusedConv = 1;

View file

@ -47,6 +47,19 @@ TEST(DequantizeLinearOpTest, Int4) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// scalar scale with int4
TEST(DequantizeLinearOpTest, Int4NoZeroPoint) {
OpTester test("DequantizeLinear", 21);
std::vector<int64_t> dims{5};
constexpr int unused_val = 0;
// Odd number of int4 values to test packing/unpacking
test.AddInput<Int4x2>("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)});
test.AddInput<float>("x_scale", {}, {2.0f});
test.AddOutput<float>("y", dims, {-16.0f, -6.0f, 2.0f, 14.0f, 4.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// scalar zero & scale with uint4
TEST(DequantizeLinearOpTest, UInt4) {
OpTester test("DequantizeLinear", 21);
@ -61,6 +74,19 @@ TEST(DequantizeLinearOpTest, UInt4) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// scalar scale with uint4
TEST(DequantizeLinearOpTest, UInt4NoZeroPoint) {
OpTester test("DequantizeLinear", 21);
std::vector<int64_t> dims{5};
constexpr int unused_val = 0;
// Odd number of uint4 values to test packing/unpacking
test.AddInput<UInt4x2>("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)});
test.AddInput<float>("x_scale", {}, {2.0f});
test.AddOutput<float>("y", dims, {0.0f, 2.0f, 6.0f, 30.0f, 4.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Test int16 DequantizeLinear (per tensor)
TEST(DequantizeLinearOpTest, Int16) {
OpTester test("DequantizeLinear", 21);