mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[DML] Register DML operators for opset 19 (#16939)
### Description <!-- Describe your changes. --> Register DML operators for opset 19. - Cast19 - Castlike19 - Constant19 - Equal19 - Identity19 - QuantizeLinear19 - DequantizeLinear19 - Reshape19 - Shape19 - Size ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: linnealovespie <linneamay@microsoft.com>
This commit is contained in:
parent
77da2ef278
commit
24b74aebcb
7 changed files with 59 additions and 32 deletions
|
|
@ -922,10 +922,12 @@ Do not modify directly.*
|
|||
|BitwiseNot|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|CastLike|*in* input:**T1**<br> *in* target_type:**T2**<br> *out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|CastLike|*in* input:**T1**<br> *in* target_type:**T2**<br> *out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||6+|**T** = tensor(float), tensor(float16)|
|
||||
|Celu|*in* X:**T**<br> *out* Y:**T**|12+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -952,7 +954,8 @@ Do not modify directly.*
|
|||
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|
||||
|Div|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -961,7 +964,8 @@ Do not modify directly.*
|
|||
|DynamicQuantizeLinear|*in* x:**T1**<br> *out* y:**T2**<br> *out* y_scale:**tensor(float)**<br> *out* y_zero_point:**T2**|11+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|Einsum|*in* Inputs:**T**<br> *out* Output:**T**|12+|**T** = tensor(float), tensor(float16)|
|
||||
|Elu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|
||||
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|
||||
|||7+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(bool)|
|
||||
|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1004,7 +1008,8 @@ Do not modify directly.*
|
|||
|Hardmax|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|||11+|**T** = tensor(float), tensor(float16)|
|
||||
|||1+|**T** = tensor(float), tensor(float16)|
|
||||
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
@ -1099,7 +1104,8 @@ Do not modify directly.*
|
|||
|||7+|**T** = tensor(float), tensor(float16)|
|
||||
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|
||||
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|||13+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|||10+|**T1** = tensor(float), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|RNN|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)|
|
||||
|||7+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1150,7 +1156,8 @@ Do not modify directly.*
|
|||
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|
||||
|||13+|**T** = tensor(float), tensor(float16)|
|
||||
|||6+|**T** = tensor(float), tensor(float16)|
|
||||
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|
||||
|
|
@ -1178,7 +1185,8 @@ Do not modify directly.*
|
|||
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
|
||||
|
|
@ -1188,7 +1196,8 @@ Do not modify directly.*
|
|||
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|
||||
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|
||||
|Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|Slice|*in* data:**T**<br> *in* starts:**Tind**<br> *in* ends:**Tind**<br> *in* axes:**Tind**<br> *in* steps:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ public:
|
|||
castDesc.OutputTensor = outputDescs.data();
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
|
||||
|
||||
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
|
||||
|
|
@ -49,5 +49,6 @@ public:
|
|||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -487,7 +487,7 @@ public:
|
|||
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
|
|
@ -497,11 +497,11 @@ public:
|
|||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
|
||||
}
|
||||
else
|
||||
{
|
||||
{
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
|
|
@ -519,13 +519,16 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
|
|||
public:
|
||||
DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt);
|
||||
|
||||
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
|
||||
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt);
|
||||
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
|
||||
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);
|
||||
|
||||
uint32_t axis = 0;
|
||||
|
||||
|
|
@ -541,9 +544,14 @@ public:
|
|||
axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false);
|
||||
}
|
||||
|
||||
// Explicitly reshape each of the inputs after the first input (scale and zero point tensors).
|
||||
// Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor).
|
||||
for (uint32_t index = 1, inputCount = gsl::narrow_cast<uint32_t>(m_inputTensorDescs.size()); index < inputCount; ++index)
|
||||
{
|
||||
if (!kernelInfo.IsInputValid(index))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
|
||||
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
|
||||
|
||||
|
|
@ -587,12 +595,8 @@ public:
|
|||
TOperatorDesc opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.ScaleTensor = &inputDescs[1];
|
||||
opDesc.ZeroPointTensor = &inputDescs[2];
|
||||
opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr;
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
|
||||
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);
|
||||
|
||||
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Cast);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(CastLike15);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(CastLike19);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(TopK7);
|
||||
|
|
@ -785,6 +786,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
|
|
@ -798,6 +800,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
|
||||
// Elementwise
|
||||
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
|
|
@ -857,8 +860,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
|
||||
|
|
@ -943,6 +948,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
|
||||
|
|
@ -1004,7 +1010,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
|
||||
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)},
|
||||
|
|
@ -1015,8 +1023,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},
|
||||
{REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},
|
||||
|
||||
|
|
|
|||
|
|
@ -1606,6 +1606,7 @@ using ShapeInferenceHelper_Expand = ExpandHelper;
|
|||
using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape13 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape14 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_Reshape19 = ReshapeHelper;
|
||||
using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper;
|
||||
using ShapeInferenceHelper_Tile = TileHelper;
|
||||
using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper<ResizeHelper, 10>;
|
||||
|
|
@ -1725,6 +1726,7 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_MatMul = MatMulHelper;
|
||||
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
|
||||
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
|
||||
|
|
@ -1750,6 +1752,7 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_Range = RangeHelper;
|
||||
|
||||
using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper;
|
||||
|
||||
using ShapeInferenceHelper_DmlFusedConv = ConvHelper;
|
||||
using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper;
|
||||
|
|
|
|||
|
|
@ -413,6 +413,16 @@ namespace OperatorHelper
|
|||
namespace OnnxOperatorSet19
|
||||
{
|
||||
static const int sc_sinceVer_AveragePool = 19;
|
||||
static const int sc_sinceVer_Cast = 19;
|
||||
static const int sc_sinceVer_CastLike = 19;
|
||||
static const int sc_sinceVer_Constant = 19;
|
||||
static const int sc_sinceVer_Equal = 19;
|
||||
static const int sc_sinceVer_Identity = 19;
|
||||
static const int sc_sinceVer_QuantizeLinear = 19;
|
||||
static const int sc_sinceVer_DequantizeLinear = 19;
|
||||
static const int sc_sinceVer_Reshape = 19;
|
||||
static const int sc_sinceVer_Shape = 19;
|
||||
static const int sc_sinceVer_Size = 19;
|
||||
}
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
|
|
|
|||
|
|
@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) {
|
|||
|
||||
// scalar zero & scale with int8
|
||||
TEST(DequantizeLinearOpTest, Int32) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
|
||||
}
|
||||
|
||||
OpTester test("DequantizeLinear", 10);
|
||||
std::vector<int64_t> dims{4};
|
||||
test.AddInput<int32_t>("x", dims, {-30, -3, 100, 127});
|
||||
|
|
@ -98,11 +93,6 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) {
|
|||
|
||||
// dequantize without zero point
|
||||
TEST(DequantizeLinearOpTest, Without_Zero_Point) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
|
||||
}
|
||||
|
||||
OpTester test("DequantizeLinear", 10);
|
||||
test.AddInput<int8_t>("x", {}, {100});
|
||||
test.AddInput<float>("x_scale", {}, {2.0f});
|
||||
|
|
|
|||
Loading…
Reference in a new issue