mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
[DML] Add int4 QDQ (#21592)
This commit is contained in:
parent
12f426c63f
commit
de6ebcbb54
9 changed files with 307 additions and 62 deletions
|
|
@ -958,7 +958,8 @@ Do not modify directly.*
|
|||
|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -992,7 +993,8 @@ Do not modify directly.*
|
|||
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|||19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|Div|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -1107,8 +1109,8 @@ Do not modify directly.*
|
|||
|MeanVarianceNormalization|*in* X:**T**<br> *out* Y:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||9+|**T** = tensor(float), tensor(float16)|
|
||||
|||1+|**T** = tensor(float), tensor(float16)|
|
||||
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|
||||
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|
||||
|Min|*in* data_0:**T**<br> *out* min:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||8+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1145,7 +1147,8 @@ Do not modify directly.*
|
|||
|||7+|**T** = tensor(float), tensor(float16)|
|
||||
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|
||||
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)|
|
||||
|||19+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|||13+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|||10+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|RNN|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1199,7 +1202,8 @@ Do not modify directly.*
|
|||
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|
||||
|||13+|**T** = tensor(float), tensor(float16)|
|
||||
|||6+|**T** = tensor(float), tensor(float16)|
|
||||
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -1230,10 +1234,11 @@ Do not modify directly.*
|
|||
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|21+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|
||||
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||6+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1242,9 +1247,9 @@ Do not modify directly.*
|
|||
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)<br/> **U** = tensor(float), tensor(float16)<br/> **V** = tensor(float), tensor(float16)|
|
||||
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|
||||
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|
||||
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Slice|*in* data:**T**<br> *in* starts:**Tind**<br> *in* ends:**Tind**<br> *in* axes:**Tind**<br> *in* steps:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|
|
@ -1262,7 +1267,8 @@ Do not modify directly.*
|
|||
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||6+|**T** = tensor(float), tensor(float16)|
|
||||
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sub|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -1284,7 +1290,8 @@ Do not modify directly.*
|
|||
|Transpose|*in* data:**T**<br> *out* transposed:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Upsample|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
|
|||
uint32_t deviceTypeMask = 0u;
|
||||
|
||||
// Form the bitmask of all supported data types.
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
|
||||
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT4; ++i)
|
||||
{
|
||||
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
|
||||
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
|
||||
|
|
|
|||
|
|
@ -247,6 +247,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
}
|
||||
|
||||
ML_TENSOR_TYPE_CASE(float);
|
||||
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
|
||||
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
|
||||
ML_TENSOR_TYPE_CASE(uint8_t);
|
||||
ML_TENSOR_TYPE_CASE(int8_t);
|
||||
ML_TENSOR_TYPE_CASE(uint16_t);
|
||||
|
|
@ -293,6 +295,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
return onnxruntime::DataTypeImpl::GetTensorType<std::string>();
|
||||
|
||||
ML_TENSOR_TYPE_CASE(float);
|
||||
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
|
||||
ML_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
|
||||
ML_TENSOR_TYPE_CASE(uint8_t);
|
||||
ML_TENSOR_TYPE_CASE(int8_t);
|
||||
ML_TENSOR_TYPE_CASE(uint16_t);
|
||||
|
|
@ -314,6 +318,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
return onnxruntime::DataTypeImpl::GetSequenceTensorType<std::string>();
|
||||
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(float);
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<false>);
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::Int4x2Base<true>);
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(uint8_t);
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(int8_t);
|
||||
ML_SEQUENCE_TENSOR_TYPE_CASE(uint16_t);
|
||||
|
|
@ -335,6 +341,8 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
return onnxruntime::DataTypeImpl::GetType<std::string>();
|
||||
|
||||
ML_PRIMITIVE_TYPE_CASE(float);
|
||||
ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base<false>);
|
||||
ML_PRIMITIVE_TYPE_CASE(onnxruntime::Int4x2Base<true>);
|
||||
ML_PRIMITIVE_TYPE_CASE(uint8_t);
|
||||
ML_PRIMITIVE_TYPE_CASE(int8_t);
|
||||
ML_PRIMITIVE_TYPE_CASE(uint16_t);
|
||||
|
|
@ -364,6 +372,12 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
case onnx::TensorProto_DataType_FLOAT:
|
||||
return MLOperatorTensorDataType::Float;
|
||||
|
||||
case onnx::TensorProto_DataType_UINT4:
|
||||
return MLOperatorTensorDataType::UInt4;
|
||||
|
||||
case onnx::TensorProto_DataType_INT4:
|
||||
return MLOperatorTensorDataType::Int4;
|
||||
|
||||
case onnx::TensorProto_DataType_UINT8:
|
||||
return MLOperatorTensorDataType::UInt8;
|
||||
|
||||
|
|
@ -455,6 +469,12 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
case MLOperatorTensorDataType::Float:
|
||||
return "tensor(float)";
|
||||
|
||||
case MLOperatorTensorDataType::UInt4:
|
||||
return "tensor(uint4)";
|
||||
|
||||
case MLOperatorTensorDataType::Int4:
|
||||
return "tensor(int4)";
|
||||
|
||||
case MLOperatorTensorDataType::UInt8:
|
||||
return "tensor(uint8)";
|
||||
|
||||
|
|
@ -509,6 +529,12 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
case MLOperatorTensorDataType::Float:
|
||||
return "seq(tensor(float))";
|
||||
|
||||
case MLOperatorTensorDataType::UInt4:
|
||||
return "seq(tensor(uint4))";
|
||||
|
||||
case MLOperatorTensorDataType::Int4:
|
||||
return "seq(tensor(int4))";
|
||||
|
||||
case MLOperatorTensorDataType::UInt8:
|
||||
return "seq(tensor(uint8))";
|
||||
|
||||
|
|
@ -1518,7 +1544,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node);
|
||||
m_graphNodeCreateInfo->nodes.push_back(std::make_unique<AbstractOperatorDesc>(std::move(abstractDesc)));
|
||||
}
|
||||
|
||||
|
||||
// There can be operators (or kernels) which don't require any input.
|
||||
assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr);
|
||||
m_graphNodeCreateInfo->inputEdges.insert(
|
||||
|
|
|
|||
|
|
@ -528,6 +528,7 @@ public:
|
|||
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
|
||||
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
|
||||
const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
|
||||
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
|
||||
|
||||
uint32_t axis = 0;
|
||||
|
|
@ -601,6 +602,141 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <typename TOperatorDesc>
|
||||
class DmlOperatorQuantization21 : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorQuantization21(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2 || kernelInfo.GetInputCount() == 3);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt);
|
||||
|
||||
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
|
||||
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
|
||||
const DML_TENSOR_DATA_TYPE outputDataType = m_outputTensorDescs[0].GetDmlDataType();
|
||||
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> quantizationTensors;
|
||||
quantizationTensors.push_back(inputDescs[1]);
|
||||
|
||||
const bool isSignedQuantization =
|
||||
inputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
|
||||
outputDataType == DML_TENSOR_DATA_TYPE_INT4 ||
|
||||
inputDataType == DML_TENSOR_DATA_TYPE_INT8 ||
|
||||
outputDataType == DML_TENSOR_DATA_TYPE_INT8;
|
||||
|
||||
if (hasZeroPointTensor || isSignedQuantization)
|
||||
{
|
||||
if (hasZeroPointTensor)
|
||||
{
|
||||
quantizationTensors.push_back(inputDescs[2]);
|
||||
}
|
||||
|
||||
TOperatorDesc opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.QuantizationType = hasZeroPointTensor ? DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT : DML_QUANTIZATION_TYPE_SCALE;
|
||||
opDesc.QuantizationTensorCount = static_cast<uint32_t>(quantizationTensors.size());
|
||||
opDesc.QuantizationTensors = quantizationTensors.data();
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For unsigned quantization, DML uses the midpoint of the datatype when zero point isn't provided (e.g. 8 for uint4 and 128 for uint8)
|
||||
// since this is the most sane default in theory that represents the midpoint of the quantization. This is also the default used by various
|
||||
// other frameworks and quantization tools. But because ONNX uses a default zero point of 0 no matter if it's signed or unsigned, we need
|
||||
// to generate a constant zero point tensor with a value of 0 to override dml's default value when it isn't provided.
|
||||
auto zeroPointDataType = std::is_same_v<TOperatorDesc, DML_QUANTIZE_OPERATOR_DESC> ? outputDataType : inputDataType;
|
||||
|
||||
// DML doesn't support int4 FILL_VALUE_CONSTANT yet, so simply create an int8 scalar and reinterpret it to an int4 tensor
|
||||
auto zeroPointInt8DataType = zeroPointDataType == DML_TENSOR_DATA_TYPE_INT4 ? DML_TENSOR_DATA_TYPE_INT8 : DML_TENSOR_DATA_TYPE_UINT8;
|
||||
|
||||
TensorDesc scalarTensorDesc(zeroPointInt8DataType, std::vector<uint32_t>(m_inputTensorDescs[1].GetDimensionCount(), 1));
|
||||
DML_TENSOR_DESC scalarDmlTensorDesc = scalarTensorDesc.GetDmlDesc();
|
||||
|
||||
// Create a tensor full of zeros
|
||||
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroPointConstantDesc = {};
|
||||
zeroPointConstantDesc.ValueDataType = zeroPointInt8DataType;
|
||||
zeroPointConstantDesc.OutputTensor = &scalarDmlTensorDesc;
|
||||
DML_OPERATOR_DESC zeroPointConstantDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroPointConstantDesc };
|
||||
|
||||
// Broadcast the zero point tensor to match the scale tensor
|
||||
TensorDesc broadcastedScalarTensorDesc(zeroPointDataType, m_inputTensorDescs[1].GetSizes(), std::vector<uint32_t>(m_inputTensorDescs[1].GetDimensionCount(), 0));
|
||||
quantizationTensors.push_back(broadcastedScalarTensorDesc.GetDmlDesc());
|
||||
|
||||
// Create the quantize/dequantize operator
|
||||
TOperatorDesc quantizationOpDesc = {};
|
||||
quantizationOpDesc.InputTensor = &inputDescs[0];
|
||||
quantizationOpDesc.QuantizationType = DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT;
|
||||
quantizationOpDesc.QuantizationTensorCount = static_cast<uint32_t>(quantizationTensors.size());
|
||||
quantizationOpDesc.QuantizationTensors = quantizationTensors.data();
|
||||
quantizationOpDesc.OutputTensor = &outputDescs[0];
|
||||
DML_OPERATOR_DESC quantizationOpDmlDesc = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &quantizationOpDesc };
|
||||
|
||||
std::array<const DML_OPERATOR_DESC*, 2> opDescs = {
|
||||
&zeroPointConstantDmlDesc,
|
||||
&quantizationOpDmlDesc,
|
||||
};
|
||||
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
inputEdges.reserve(2);
|
||||
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
intermediateEdges.reserve(1);
|
||||
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
outputEdges.reserve(1);
|
||||
|
||||
// Create an edge between the input tensor and the quantization operator
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToQuantizeEdge{};
|
||||
inputToQuantizeEdge.GraphInputIndex = 0;
|
||||
inputToQuantizeEdge.ToNodeIndex = 1;
|
||||
inputToQuantizeEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToQuantizeEdge);
|
||||
|
||||
// Create an edge between the scale and the quantization operator
|
||||
DML_INPUT_GRAPH_EDGE_DESC scaleToQuantizeEdge{};
|
||||
scaleToQuantizeEdge.GraphInputIndex = 1;
|
||||
scaleToQuantizeEdge.ToNodeIndex = 1;
|
||||
scaleToQuantizeEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(scaleToQuantizeEdge);
|
||||
|
||||
// Create an edge between the generated zero point tensor and the quantization operator
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroPointToQuantizeEdge{};
|
||||
zeroPointToQuantizeEdge.FromNodeIndex = 0;
|
||||
zeroPointToQuantizeEdge.FromNodeOutputIndex = 0;
|
||||
zeroPointToQuantizeEdge.ToNodeIndex = 1;
|
||||
zeroPointToQuantizeEdge.ToNodeInputIndex = 2;
|
||||
intermediateEdges.push_back(zeroPointToQuantizeEdge);
|
||||
|
||||
// Create an edge between the output of the quantization operator and the output of the graph
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC quantizeToOutputEdge{};
|
||||
quantizeToOutputEdge.FromNodeIndex = 1;
|
||||
quantizeToOutputEdge.FromNodeOutputIndex = 0;
|
||||
quantizeToOutputEdge.GraphOutputIndex = 0;
|
||||
outputEdges.push_back(quantizeToOutputEdge);
|
||||
|
||||
// Create the graph
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
|
||||
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
|
||||
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
|
||||
operatorGraphDesc.outputEdges = outputEdges.data();
|
||||
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
|
||||
operatorGraphDesc.nodes = opDescs.data();
|
||||
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class DmlOperatorElementwiseIf : public DmlOperator
|
||||
{
|
||||
public:
|
||||
|
|
@ -777,18 +913,20 @@ DML_OP_DEFINE_CREATION_FUNCTION(Max, DmlOperatorElementwiseBinaryLo
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
|
||||
|
||||
// Operators with extra attributes:
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear21, DmlOperatorQuantization21<DML_QUANTIZE_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear21, DmlOperatorQuantization21<DML_DEQUANTIZE_OPERATOR_DESC>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Mod, DmlOperatorElementwiseMod);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(BitShift, DmlOperatorElementwiseBitShift);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);
|
||||
|
||||
// Fused operators:
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedAdd, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
|
||||
|
|
|
|||
|
|
@ -225,39 +225,43 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
namespace Dml
|
||||
{
|
||||
|
||||
enum class SupportedTensorDataTypes : uint32_t
|
||||
enum class SupportedTensorDataTypes : uint64_t
|
||||
{
|
||||
Undefined = 1<<0,
|
||||
Float32 = 1<<1,
|
||||
UInt8 = 1<<2,
|
||||
Int8 = 1<<3,
|
||||
UInt16 = 1<<4,
|
||||
Int16 = 1<<5,
|
||||
Int32 = 1<<6,
|
||||
Int64 = 1<<7,
|
||||
String = 1<<8,
|
||||
Bool = 1<<9,
|
||||
Float16 = 1<<10,
|
||||
Float64 = 1<<11,
|
||||
UInt32 = 1<<12,
|
||||
UInt64 = 1<<13,
|
||||
Complex64 = 1<<14,
|
||||
Complex128 = 1<<15,
|
||||
SequenceFloat32 = 1<<16,
|
||||
SequenceUInt8 = 1<<17,
|
||||
SequenceInt8 = 1<<18,
|
||||
SequenceUInt16 = 1<<19,
|
||||
SequenceInt16 = 1<<20,
|
||||
SequenceInt32 = 1<<21,
|
||||
SequenceInt64 = 1<<22,
|
||||
SequenceString = 1<<23,
|
||||
SequenceBool = 1<<24,
|
||||
SequenceFloat16 = 1<<25,
|
||||
SequenceFloat64 = 1<<26,
|
||||
SequenceUInt32 = 1<<27,
|
||||
SequenceUInt64 = 1<<28,
|
||||
SequenceComplex64 = 1<<29,
|
||||
SequenceComplex128 = 1<<30,
|
||||
Undefined = 1LLU<<0,
|
||||
Float32 = 1LLU<<1,
|
||||
UInt4 = 1LLU<<2,
|
||||
Int4 = 1LLU<<3,
|
||||
UInt8 = 1LLU<<4,
|
||||
Int8 = 1LLU<<5,
|
||||
UInt16 = 1LLU<<6,
|
||||
Int16 = 1LLU<<7,
|
||||
Int32 = 1LLU<<8,
|
||||
Int64 = 1LLU<<9,
|
||||
String = 1LLU<<10,
|
||||
Bool = 1LLU<<11,
|
||||
Float16 = 1LLU<<12,
|
||||
Float64 = 1LLU<<13,
|
||||
UInt32 = 1LLU<<14,
|
||||
UInt64 = 1LLU<<15,
|
||||
Complex64 = 1LLU<<16,
|
||||
Complex128 = 1LLU<<17,
|
||||
SequenceFloat32 = 1LLU<<18,
|
||||
SequenceUInt4 = 1LLU<<19,
|
||||
SequenceInt4 = 1LLU<<20,
|
||||
SequenceUInt8 = 1LLU<<21,
|
||||
SequenceInt8 = 1LLU<<22,
|
||||
SequenceUInt16 = 1LLU<<23,
|
||||
SequenceInt16 = 1LLU<<24,
|
||||
SequenceInt32 = 1LLU<<25,
|
||||
SequenceInt64 = 1LLU<<26,
|
||||
SequenceString = 1LLU<<27,
|
||||
SequenceBool = 1LLU<<28,
|
||||
SequenceFloat16 = 1LLU<<29,
|
||||
SequenceFloat64 = 1LLU<<30,
|
||||
SequenceUInt32 = 1LLU<<31,
|
||||
SequenceUInt64 = 1LLU<<32,
|
||||
SequenceComplex64 = 1LLU<<33,
|
||||
SequenceComplex128 = 1LLU<<34,
|
||||
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
|
||||
Ints32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64,
|
||||
|
|
@ -273,7 +277,7 @@ enum class SupportedTensorDataTypes : uint32_t
|
|||
Ints16Bit = UInt16|Int16,
|
||||
Ints32Bit = UInt32|Int32,
|
||||
Ints64Bit = UInt64|Int64,
|
||||
All = static_cast<uint32_t>(-1),
|
||||
All = static_cast<uint64_t>(-1),
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
|
||||
|
||||
|
|
@ -463,7 +467,9 @@ DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedMatMul);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedAdd);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DmlFusedSum);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear21);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear21);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearSigmoid);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Sign);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(IsNaN);
|
||||
|
|
@ -599,8 +605,10 @@ constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatte
|
|||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear19 = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear21 = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear19 = { SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::Float16to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear21 = { SupportedTensorDataTypes::UInt4 | SupportedTensorDataTypes::Int4 | SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8, SupportedTensorDataTypes::Float16to32 };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
|
||||
|
|
@ -848,13 +856,16 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_COPY( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(13, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(21, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(13, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(21, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(21, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Elementwise
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -915,9 +926,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 21, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear21, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 21, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear21, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
|
||||
|
|
@ -1071,6 +1084,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 21, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
|
||||
|
|
@ -1084,6 +1098,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 21, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
|
|
@ -1246,6 +1261,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
|
||||
// Scalars
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(TensorEdgeDesc<float>());
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::UInt4 )) edgeDescs.push_back(TensorEdgeDesc<::MLUInt4x2>());
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::Int4 )) edgeDescs.push_back(TensorEdgeDesc<::MLInt4x2>());
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(TensorEdgeDesc<uint8_t>());
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(TensorEdgeDesc<int8_t>());
|
||||
if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(TensorEdgeDesc<uint16_t>());
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
|
||||
#include "MLOperatorAuthorPrivate.h"
|
||||
#include "core/framework/int4.h"
|
||||
#include <gsl/gsl>
|
||||
#include <optional>
|
||||
|
||||
|
|
@ -20,6 +21,8 @@ namespace onnxruntime
|
|||
}
|
||||
|
||||
using MLFloat16 = onnxruntime::MLFloat16;
|
||||
using MLUInt4x2 = onnxruntime::Int4x2Base<false>;
|
||||
using MLInt4x2 = onnxruntime::Int4x2Base<true>;
|
||||
|
||||
//
|
||||
// Traits for numeric attribute types
|
||||
|
|
@ -43,6 +46,18 @@ struct MLTypeTraits<int32_t>
|
|||
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MLTypeTraits<onnxruntime::Int4x2Base<false>>
|
||||
{
|
||||
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::UInt4;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MLTypeTraits<onnxruntime::Int4x2Base<true>>
|
||||
{
|
||||
static const MLOperatorTensorDataType TensorType = MLOperatorTensorDataType::Int4;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MLTypeTraits<uint8_t>
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1696,9 +1696,11 @@ using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
|
|||
using ShapeInferenceHelper_Squeeze7 = VersionedOpsetHelper<SqueezeHelper, 7>;
|
||||
using ShapeInferenceHelper_Squeeze11 = VersionedOpsetHelper<SqueezeHelper, 11>;
|
||||
using ShapeInferenceHelper_Squeeze13 = VersionedOpsetHelper<SqueezeHelper, 13>;
|
||||
using ShapeInferenceHelper_Squeeze21 = VersionedOpsetHelper<SqueezeHelper, 21>;
|
||||
using ShapeInferenceHelper_Unsqueeze7 = VersionedOpsetHelper<UnsqueezeHelper, 7>;
|
||||
using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper<UnsqueezeHelper, 11>;
|
||||
using ShapeInferenceHelper_Unsqueeze13 = VersionedOpsetHelper<UnsqueezeHelper, 13>;
|
||||
using ShapeInferenceHelper_Unsqueeze21 = VersionedOpsetHelper<UnsqueezeHelper, 21>;
|
||||
using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Col2Im = Col2ImHelper;
|
||||
|
|
@ -1708,6 +1710,7 @@ using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
|
|||
using ShapeInferenceHelper_Reshape13 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape14 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape19 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape21 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper;
|
||||
using ShapeInferenceHelper_Tile = TileHelper;
|
||||
using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper<ResizeHelper, 10>;
|
||||
|
|
@ -1754,7 +1757,9 @@ using ShapeInferenceHelper_Asin = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Atan = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QuantizeLinear21 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_DequantizeLinear21 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QAttention = QAttentionHelper;
|
||||
using ShapeInferenceHelper_Attention = AttentionHelper;
|
||||
|
|
|
|||
|
|
@ -437,6 +437,17 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_ReduceMin = 20;
|
||||
}
|
||||
|
||||
namespace OnnxOperatorSet21
|
||||
{
|
||||
static const int sc_sinceVer_QuantizeLinear = 21;
|
||||
static const int sc_sinceVer_DequantizeLinear = 21;
|
||||
static const int sc_sinceVer_Squeeze = 21;
|
||||
static const int sc_sinceVer_Unsqueeze = 21;
|
||||
static const int sc_sinceVer_Reshape = 21;
|
||||
static const int sc_sinceVer_Cast = 21;
|
||||
static const int sc_sinceVer_Shape = 21;
|
||||
}
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
{
|
||||
static const int sc_sinceVer_DmlFusedConv = 1;
|
||||
|
|
|
|||
|
|
@ -47,6 +47,19 @@ TEST(DequantizeLinearOpTest, Int4) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// scalar scale with int4
|
||||
TEST(DequantizeLinearOpTest, Int4NoZeroPoint) {
|
||||
OpTester test("DequantizeLinear", 21);
|
||||
std::vector<int64_t> dims{5};
|
||||
constexpr int unused_val = 0;
|
||||
|
||||
// Odd number of int4 values to test packing/unpacking
|
||||
test.AddInput<Int4x2>("x", dims, {Int4x2(-8, -3), Int4x2(1, 7), Int4x2(2, unused_val)});
|
||||
test.AddInput<float>("x_scale", {}, {2.0f});
|
||||
test.AddOutput<float>("y", dims, {-16.0f, -6.0f, 2.0f, 14.0f, 4.0f});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// scalar zero & scale with uint4
|
||||
TEST(DequantizeLinearOpTest, UInt4) {
|
||||
OpTester test("DequantizeLinear", 21);
|
||||
|
|
@ -61,6 +74,19 @@ TEST(DequantizeLinearOpTest, UInt4) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// scalar scale with uint4
|
||||
TEST(DequantizeLinearOpTest, UInt4NoZeroPoint) {
|
||||
OpTester test("DequantizeLinear", 21);
|
||||
std::vector<int64_t> dims{5};
|
||||
constexpr int unused_val = 0;
|
||||
|
||||
// Odd number of uint4 values to test packing/unpacking
|
||||
test.AddInput<UInt4x2>("x", dims, {UInt4x2(0, 1), UInt4x2(3, 15), UInt4x2(2, unused_val)});
|
||||
test.AddInput<float>("x_scale", {}, {2.0f});
|
||||
test.AddOutput<float>("y", dims, {0.0f, 2.0f, 6.0f, 30.0f, 4.0f});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// Test int16 DequantizeLinear (per tensor)
|
||||
TEST(DequantizeLinearOpTest, Int16) {
|
||||
OpTester test("DequantizeLinear", 21);
|
||||
|
|
|
|||
Loading…
Reference in a new issue