From 803837df63ab9b36a6aacf7c5686fabaef19a6dc Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Fri, 7 May 2021 20:17:29 -0700
Subject: [PATCH] Add 4dmask support for attention cuda kernel (#7591)
* checkin
* add 4dmask support in attention cuda op
* trim
* add comments
* fix build/test error
* review comments and add tests
* sync doc
* review comments
* minor change
---
docs/ContribOperators.md | 2 +-
docs/OperatorKernels.md | 288 +++++++++++++++++-
onnxruntime/contrib_ops/cpu/bert/attention.cc | 9 +-
.../contrib_ops/cpu/bert/attention_helper.h | 7 +
.../contrib_ops/cuda/bert/attention_impl.cu | 9 +-
.../contrib_ops/cuda/bert/attention_softmax.h | 35 ++-
.../core/graph/contrib_ops/contrib_defs.cc | 3 +-
.../test/contrib_ops/attention_op_test.cc | 69 ++++-
8 files changed, 395 insertions(+), 27 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 905bd37ed2..fc9d26441c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -103,7 +103,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias : T
1D input tensor with shape (3 * hidden_size)
mask_index (optional) : M
-Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
+Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 55a12473ea..7ad397f6ef 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -33,7 +33,8 @@
|AveragePool|(*in* X:**T**, *out* Y:**T**)|11+|**T** = tensor(float)|
|||10|**T** = tensor(float)|
|||[7, 9]|**T** = tensor(float)|
-|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* input_mean:**U**, *in* input_var:**U**, *out* Y:**T**, *out* running_mean:**U**, *out* running_var:**U**) or (*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|9+|**T** = tensor(double), tensor(float)|
+|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* input_mean:**U**, *in* input_var:**U**, *out* Y:**T**, *out* running_mean:**U**, *out* running_var:**U**) or (*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|14+|**T** = tensor(double), tensor(float)|
+|||[9, 13]|**T** = tensor(double), tensor(float)|
|||[7, 8]|**T** = tensor(double), tensor(float)|
|Binarizer|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
|BitShift|(*in* X:**T**, *in* Y:**T**, *out* Z:**T**)|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -124,7 +125,7 @@
|Hardmax|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(float)|
|||[11, 12]|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
-|Identity|(*in* input:**T**, *out* output:**T**) or (*in* input:**V**, *out* output:**V**)|14+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Identity|(*in* input:**T**, *out* output:**T**) or (*in* input:**V**, *out* output:**V**)|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|If|(*in* cond:**B**, *out* outputs:**V**)|13+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -141,7 +142,7 @@
|LSTM|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *in* initial_c:**T**, *in* P:**T**, *out* Y:**T**, *out* Y_h:**T**, *out* Y_c:**T**)|7+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
|LabelEncoder|(*in* X:**T1**, *out* Y:**T2**)|2+|**T1** = tensor(float), tensor(int64), tensor(string)
**T2** = tensor(float), tensor(int64), tensor(string)|
|||1|**T1** = tensor(int64), tensor(string)
**T2** = tensor(int64), tensor(string)|
-|LayerNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *out* Y:**T**, *out* mean:**U**, *out* inv_std_var:**U**)|1+|**T** = tensor(double), tensor(float)|
+|LayerNormalization|(*in* X:**T**, *in* Scale:**T**, *in* B:**T**, *out* Y:**T**, *out* Mean:**U**, *out* InvStdDev:**U**)|1+|**T** = tensor(double), tensor(float)|
|LeakyRelu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(float)|
|Less|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
@@ -405,3 +406,284 @@
|Upsample|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
| |
| |
+
+
+## Operators implemented by CUDAExecutionProvider
+
+| Op Name | Parameters | OpSet Version | Types Supported |
+|---------|------------|---------------|-----------------|
+|**Operator Domain:** *ai.onnx.ml*||||
+|Abs|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Add|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|Affine|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|And|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
+|ArgMax|(*in* data:**T**, *out* reduced:**tensor(int64)**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ArgMin|(*in* data:**T**, *out* reduced:**tensor(int64)**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|AveragePool|(*in* X:**T**, *out* Y:**T**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|||[7, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|BatchNormalization|(*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* input_mean:**U**, *in* input_var:**U**, *out* Y:**T**, *out* running_mean:**U**, *out* running_var:**U**) or (*in* X:**T**, *in* scale:**T**, *in* B:**T**, *in* mean:**T**, *in* var:**T**, *out* Y:**T**, *out* mean:**T**, *out* var:**T**, *out* saved_mean:**T**, *out* saved_var:**T**)|9+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Cast|(*in* input:**T1**, *out* output:**T2**)|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Ceil|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Clip|(*in* input:**T**, *in* min:**T**, *in* max:**T**, *out* output:**T**) or (*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|
+|||11|**T** = tensor(float)|
+|||[6, 10]|**T** = tensor(float)|
+|Compress|(*in* input:**T**, *in* condition:**T1**, *out* output:**T**)|11+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|Concat|(*in* inputs:**T**, *out* concat_result:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ConstantOfShape|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Conv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ConvTranspose|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *out* Y:**T**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Cos|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Crop|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|CumSum|(*in* x:**T**, *in* axis:**T2**, *out* y:**T**)|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)|
+|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)|
+|DequantizeLinear|(*in* x:**T**, *in* x_scale:**tensor(float)**, *in* x_zero_point:**T**, *out* y:**tensor(float)**)|10+|**T** = tensor(int8), tensor(uint8)|
+|Div|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|Dropout|(*in* data:**T**, *in* ratio:**T1**, *in* training_mode:**T2**, *out* output:**T**, *out* mask:**T2**) or (*in* data:**T**, *out* output:**T**, *out* mask:**T**) or (*in* data:**T**, *out* output:**T**, *out* mask:**T1**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|||[10, 11]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bool)|
+|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
+|DynamicSlice|(*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *out* output:**T**)|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|Einsum|(*in* Inputs:**T**, *out* Output:**T**)|12+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Elu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Equal|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)|
+|Erf|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Exp|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Expand|(*in* input:**T**, *in* shape:**tensor(int64)**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|EyeLike|(*in* input:**T1**, *out* output:**T2**)|9+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)|
+|Flatten|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 8]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Floor|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|GRU|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)|
+|Gather|(*in* data:**T**, *in* indices:**Tind**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|GatherElements|(*in* data:**T**, *in* indices:**Tind**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|GatherND|(*in* data:**T**, *in* indices:**tensor(int64)**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)|
+|Gemm|(*in* A:**T**, *in* B:**T**, *in* C:**T**, *out* Y:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|GlobalAveragePool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|GlobalMaxPool|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Greater|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|GreaterOrEqual|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|HardSigmoid|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Identity|(*in* input:**T**, *out* output:**T**) or (*in* input:**V**, *out* output:**V**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|If|(*in* cond:**B**, *out* outputs:**V**)|13+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ImageScaler|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|InstanceNormalization|(*in* input:**T**, *in* scale:**T**, *in* B:**T**, *out* output:**T**)|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LRN|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|LSTM|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *in* initial_c:**T**, *in* P:**T**, *out* Y:**T**, *out* Y_h:**T**, *out* Y_c:**T**)|7+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)|
+|LayerNormalization|(*in* X:**T**, *in* Scale:**T**, *in* B:**T**, *out* Y:**T**, *out* Mean:**U**, *out* InvStdDev:**U**)|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)|
+|LeakyRelu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Less|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|LessOrEqual|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|Log|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|LogSoftmax|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Loop|(*in* M:**I**, *in* cond:**B**, *in* v_initial:**V**, *out* v_final_and_scan_outputs:**V**)|13+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MatMul|(*in* A:**T**, *in* B:**T**, *out* Y:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|MatMulInteger|(*in* A:**T1**, *in* B:**T2**, *in* a_zero_point:**T1**, *in* b_zero_point:**T2**, *out* Y:**T3**)|10+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(int32)|
+|Max|(*in* data_0:**T**, *out* max:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|MaxPool|(*in* X:**T**, *out* Y:**T**) or (*in* X:**T**, *out* Y:**T**, *out* Indices:**I**)|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int8), tensor(uint8)|
+|||11|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|||[8, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 7]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
+|MemcpyFromHost|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Min|(*in* data_0:**T**, *out* min:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|Mul|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|Neg|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
+|NonZero|(*in* X:**T**, *out* Y:**tensor(int64)**)|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
+|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
+|Not|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(bool)
**T1** = tensor(bool)|
+|OneHot|(*in* indices:**T1**, *in* depth:**T2**, *in* values:**T3**, *out* output:**T3**)|11+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)
**T3** = tensor(float), tensor(float16), tensor(int64)|
+|Or|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
+|PRelu|(*in* X:**T**, *in* slope:**T**, *out* Y:**T**)|9+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Pad|(*in* data:**T**, *in* pads:**tensor(int64)**, *in* constant_value:**T**, *out* output:**T**) or (*in* data:**T**, *out* output:**T**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ParametricSoftplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Pow|(*in* X:**T**, *in* Y:**T**, *out* Z:**T**) or (*in* X:**T**, *in* Y:**T1**, *out* Z:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|QuantizeLinear|(*in* x:**T1**, *in* y_scale:**tensor(float)**, *in* y_zero_point:**T2**, *out* y:**T2**)|10+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
+|RNN|(*in* X:**T**, *in* W:**T**, *in* R:**T**, *in* B:**T**, *in* sequence_lens:**T1**, *in* initial_h:**T**, *out* Y:**T**, *out* Y_h:**T**)|7+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)|
+|Range|(*in* start:**T**, *in* limit:**T**, *in* delta:**T**, *out* output:**T**)|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
+|Reciprocal|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ReduceL1|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|ReduceL2|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|ReduceLogSum|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ReduceLogSumExp|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ReduceMax|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|ReduceMean|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|ReduceMin|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int8), tensor(uint8)|
+|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|ReduceProd|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)|
+|ReduceSum|(*in* data:**T**, *in* axes:**tensor(int64)**, *out* reduced:**T**) or (*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+|ReduceSumSquare|(*in* data:**T**, *out* reduced:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Relu|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Reshape|(*in* data:**T**, *in* shape:**tensor(int64)**, *out* reshaped:**T**) or (*in* data:**T**, *out* reshaped:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)|
+|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)|
+|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Resize|(*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**) or (*in* X:**T1**, *in* roi:**T2**, *in* scales:**tensor(float)**, *in* sizes:**tensor(int64)**, *out* Y:**T1**)|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|ReverseSequence|(*in* input:**T**, *in* sequence_lens:**tensor(int64)**, *out* Y:**T**)|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|RoiAlign|(*in* X:**T1**, *in* rois:**T1**, *in* batch_indices:**T2**, *out* Y:**T1**)|10+|**T** = tensor(double), tensor(float)
**T2** = tensor(int64)|
+|Round|(*in* X:**T**, *out* Y:**T**)|11+|**T** = tensor(double), tensor(float), tensor(float16)|
+|ScaledTanh|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Scan|(*in* initial_state_and_scan_inputs:**V**, *out* final_state_and_scan_outputs:**V**) or (*in* sequence_lens:**I**, *in* initial_state_and_scan_inputs:**V**, *out* final_state_and_scan_outputs:**V**)|11+|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 10]|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Scatter|(*in* data:**T**, *in* indices:**Tind**, *in* updates:**T**, *out* output:**T**)|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|(*in* data:**T**, *in* indices:**Tind**, *in* updates:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|ScatterND|(*in* data:**T**, *in* indices:**tensor(int64)**, *in* updates:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Selu|(*in* X:**T**, *out* Y:**T**)|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Shape|(*in* data:**T**, *out* shape:**T1**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Shrink|(*in* input:**T**, *out* output:**T**)|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Sigmoid|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|SimplifiedLayerNormalization|(*in* X:**T**, *in* scale:**T**, *out* Y:**T**, *out* inv_std_var:**U**)|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)|
+|Sin|(*in* input:**T**, *out* output:**T**)|7+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Size|(*in* data:**T**, *out* size:**T1**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Slice|(*in* data:**T**, *in* starts:**Tind**, *in* ends:**Tind**, *in* axes:**Tind**, *in* steps:**Tind**, *out* output:**T**) or (*in* data:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)|
+|||10|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)|
+|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)|
+|Softmax|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Softplus|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Softsign|(*in* input:**T**, *out* output:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Split|(*in* input:**T**, *in* split:**T**, *out* outputs...:**T**) or (*in* input:**T**, *in* split:**tensor(int64)**, *out* outputs:**T**) or (*in* input:**T**, *out* outputs:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Sqrt|(*in* X:**T**, *out* Y:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Squeeze|(*in* data:**T**, *in* axes:**tensor(int64)**, *out* squeezed:**T**) or (*in* data:**T**, *out* squeezed:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Sub|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
+|Sum|(*in* data_0:**T**, *out* sum:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[8, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[6, 7]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|Tanh|(*in* input:**T**, *out* output:**T**)|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ThresholdedRelu|(*in* X:**T**, *out* Y:**T**)|10+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Tile|(*in* input:**T**, *in* repeats:**T1**, *out* output:**T**) or (*in* input:**T**, *in* tiles:**T**, *in* axis:**T**, *out* output:**T**)|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)|
+|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)|
+|TopK|(*in* X:**T**, *in* K:**tensor(int64)**, *out* Values:**T**, *out* Indices:**I**) or (*in* X:**T**, *out* Values:**T**, *out* Indices:**I**)|11+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||10|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Transpose|(*in* data:**T**, *out* transposed:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Unsqueeze|(*in* data:**T**, *in* axes:**tensor(int64)**, *out* expanded:**T**) or (*in* data:**T**, *out* expanded:**T**)|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Upsample|(*in* X:**T**, *in* scales:**tensor(float)**, *out* Y:**T**) or (*in* X:**T**, *out* Y:**T**)|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Where|(*in* condition:**B**, *in* X:**T**, *in* Y:**T**, *out* output:**T**)|9+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
+|Xor|(*in* A:**T**, *in* B:**T**, *out* C:**T1**)|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
+| |
+| |
+|**Operator Domain:** *com.microsoft*||||
+|Attention|(*in* input:**T**, *in* weight:**T**, *in* bias:**T**, *in* mask_index:**M**, *in* past:**T**, *out* output:**T**, *out* present:**T**)|1+|**T** = tensor(float), tensor(float16)|
+|BiasDropout|(*in* data:**T**, *in* bias:**T**, *in* residual:**T**, *in* ratio:**T1**, *in* training_mode:**T2**, *out* output:**T**, *out* mask:**T2**)|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|BiasGelu|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|BiasSoftmax|(*in* data:**T**, *in* bias:**T**, *out* output:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|ComplexMul|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|1+|**T** = tensor(float), tensor(float16)|
+|ComplexMulConj|(*in* A:**T**, *in* B:**T**, *out* C:**T**)|1+|**T** = tensor(float), tensor(float16)|
+|ConvTransposeWithDynamicPads|(*in* X:**T**, *in* W:**T**, *in* Pads:**tensor(int64)**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
+|DequantizeLinear|(*in* x:**T1**, *in* x_scale:**T2**, *in* x_zero_point:**T1**, *out* y:**T2**)|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)|
+|EmbedLayerNormalization|(*in* input_ids:**T1**, *in* segment_ids:**T1**, *in* word_embedding:**T**, *in* position_embedding:**T**, *in* segment_embedding:**T**, *in* gamma:**T**, *in* beta:**T**, *in* mask:**T1**, *out* output:**T**, *out* mask_index:**T1**)|1+|**T** = tensor(float), tensor(float16)|
+|FastGelu|(*in* X:**T**, *in* bias:**T**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|FusedConv|(*in* X:**T**, *in* W:**T**, *in* B:**T**, *in* Z:**T**, *out* Y:**T**)|1+|**T** = tensor(float)|
+|FusedMatMul|(*in* A:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|Gelu|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Inverse|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Irfft|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LongformerAttention|(*in* input:**T**, *in* weight:**T**, *in* bias:**T**, *in* mask:**T**, *in* global_weight:**T**, *in* global_bias:**T**, *in* global:**G**, *out* output:**T**)|1+|**T** = tensor(float), tensor(float16)|
+|QAttention|(*in* input:**T1**, *in* weight:**T2**, *in* bias:**T3**, *in* input_scale:**T3**, *in* weight_scale:**T3**, *in* mask_index:**T4**, *in* input_zero_point:**T1**, *in* weight_zero_point:**T2**, *in* past:**T3**, *out* output:**T3**, *out* present:**T3**)|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
+|QuantizeLinear|(*in* x:**T1**, *in* y_scale:**T1**, *in* y_zero_point:**T2**, *out* y:**T2**)|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)|
+|Rfft|(*in* X:**T**, *out* Y:**T**)|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|SkipLayerNormalization|(*in* input:**T**, *in* skip:**T**, *in* gamma:**T**, *in* beta:**T**, *in* bias:**T**, *out* output:**T**, *out* mean:**U**, *out* inv_std_var:**U**)|1+|**T** = tensor(float), tensor(float16)|
+|TransposeMatMul|(*in* A:**T**, *in* B:**T**, *out* Y:**T**)|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+| |
+| |
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 68e5b6815e..827692ede8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -146,8 +146,15 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast(mask_dims[2]) != past_sequence_length + sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
}
+ } else if (mask_dims.size() == 4) {
+ if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] || mask_dims[2] < past_sequence_length + sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have shape batch_size x 1 x max_sequence_length x max_sequence_length)");
+ }
+ if (is_unidirectional_ == true) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false");
+ }
} else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
mask_dims.size());
}
}
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
index 5afe24166a..cf9408a2db 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -7,6 +7,7 @@
#include "core/util/math_cpuonly.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
+#include "core/providers/common.h"
#include "core/mlas/inc/mlas.h"
using onnxruntime::concurrency::ThreadPool;
@@ -72,6 +73,12 @@ void PrepareMask(const int32_t* mask_index,
// mask_data has been filled with 0, and its shape is BxSxS*
T* p_mask = mask_data;
+ // 4D mask in Megatron GPT2 is currently not support in CPU kernel
+ if (nullptr != mask_index_dims && mask_index_dims->size() == 4) {
+ ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "4D mask in attention cpu kernel is not supported");
+ return;
+ }
+
// For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any.
if (nullptr != mask_index_dims && mask_index_dims->size() == 3) {
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 8f7616c588..d232e70a30 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -117,7 +117,7 @@ bool QkvToContext(
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
// TODO: move scalar to softmax computation since converting 1/Sqrt(H) to half might have loss in precision.
T alpha = use_raw_attention_mask ? one : (T)(rsqrt_head_size);
-
+
if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch,
q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop))) {
@@ -125,8 +125,11 @@ bool QkvToContext(
}
// apply softmax and store result P to scratch2: BxNxSxS*
- if (use_raw_attention_mask) { // 2d or 3d attention mask
- if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size, static_cast(mask_index_dims->size()))) {
+ if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
+ const int mask_dimension = static_cast(mask_index_dims->size());
+ const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims->at(3) : 0;
+ if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional,
+ rsqrt_head_size, mask_dimension, static_cast(max_sequence_length))) {
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
index e24d284559..a3cc1a9a7d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -161,12 +161,13 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
template
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
const int sequence_length,
- const int* attention_mask, // 2D or 3D attention mask
+ const int* attention_mask, // 2D, 3D or 4D attention mask
const T* input,
T* output,
const bool is_unidirectional,
const float scalar,
- const int mask_dimension) {
+ const int mask_dimension,
+ const int max_sequence_length) {
using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage tmp_storage;
@@ -180,7 +181,16 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
if (threadIdx.x < all_sequence_length) {
const int batch_index = blockIdx.y;
const int sequence_index = blockIdx.x % sequence_length;
- const int mask_offset = (mask_dimension == 2) ? batch_index * all_sequence_length + threadIdx.x : batch_index * sequence_length * all_sequence_length + sequence_index * all_sequence_length + threadIdx.x;
+ int mask_offset = 0;
+ if (mask_dimension == 2) {
+ mask_offset = batch_index * all_sequence_length + threadIdx.x;
+ } else if (mask_dimension == 3) {
+ mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
+ } else if (mask_dimension == 4){
+ // Megatron code:
+ // ltor_mask = ltor_mask[..., (attention_scores.size(3)-hidden_states.size(1)):attention_scores.size(3), :attention_scores.size(3)]
+ mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x;
+ }
const int& mask = attention_mask[mask_offset];
float mask_value = mask > 0 ? 0.0f : -10000.0f;
@@ -303,8 +313,9 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq
}
template
-__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, const int mask_dimension) {
- SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output,
+ const bool is_unidirectional, const float scalar, const int mask_dimension, const int max_sequence_length) {
+ SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
}
template
@@ -350,33 +361,33 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length
template
bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar,
- const int mask_dimension) {
+ const int mask_dimension, const int max_sequence_length) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxWithRawMaskSmallKernel
- <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
+ <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
} else {
ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024.");
}
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index d95b95055c..903b39e753 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -356,7 +356,8 @@ and present state are optional. Present state could appear in output even when p
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T")
.Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T")
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
- .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
+ .Input(3, "mask_index", "Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)"
+ "or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
.Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index b4fec55736..97da885edd 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -14,7 +14,8 @@ enum MaskIndexType {
kMaskIndexEndAndStart,
kMaskRaw,
kMask3D,
- kMaskDummy // Dummy mask with shape [1, 1] or [batch_size, 1]
+ kMaskDummy, // Dummy mask with shape [1, 1] or [batch_size, 1]
+ kMask4D // Megatron GPT2 mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
};
static void RunAttentionTest(
@@ -35,12 +36,14 @@ static void RunAttentionTest(
const std::vector* past_data = nullptr,
const std::vector* present_data = nullptr,
MaskIndexType mask_index_type = kMaskIndexEnd,
- int input_hidden_size = 0) {
+ int input_hidden_size = 0,
+ int max_sequence_length = 0,
+ bool only_enable_cuda = false) {
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant;
- bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16;
+ bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
int head_size = hidden_size / number_of_heads;
if (enable_cpu || enable_cuda) {
@@ -57,6 +60,7 @@ static void RunAttentionTest(
std::vector mask_index_dims_3 = {batch_size, past_sequence_length + sequence_length};
std::vector mask_index_dims_4 = {batch_size, 1};
std::vector mask_index_dims_5 = {batch_size, sequence_length, past_sequence_length + sequence_length};
+ std::vector mask_index_dims_6 = {batch_size, 1, max_sequence_length, max_sequence_length};
std::vector mask_index_dims;
switch (mask_index_type) {
case kMaskIndexEnd:
@@ -74,6 +78,9 @@ static void RunAttentionTest(
case kMask3D:
mask_index_dims = mask_index_dims_5;
break;
+ case kMask4D:
+ mask_index_dims = mask_index_dims_6;
+ break;
default:
assert(0); // shall not reach here.
break;
@@ -146,15 +153,19 @@ static void RunAttentionTest(
const std::vector* past_data = nullptr,
const std::vector* present_data = nullptr,
MaskIndexType mask_index_type = kMaskIndexEnd,
- int input_hidden_size = 0) {
+ int input_hidden_size = 0,
+ int max_sequence_length = 0,
+ bool only_enable_cuda = false) {
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
- past_data, present_data, mask_index_type, input_hidden_size);
+ past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
+ only_enable_cuda);
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
- past_data, present_data, mask_index_type, input_hidden_size);
+ past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
+ only_enable_cuda);
}
TEST(AttentionTest, AttentionBatch1) {
@@ -1363,6 +1374,52 @@ TEST(AttentionTest, AttentionDummyMask2D) {
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskDummy);
}
+TEST(AttentionTest, Attention4DMask) {
+ int batch_size = 1;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ 0.5f, 0.2f, 0.3f, -0.6f,
+ 0.8f, -0.5f, 0.0f, 1.f};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
+
+ // Test 4D mask Bx1xmax_Sxmax_S
+ std::vector mask_index_data = {
+ 0, 0, 0, 0,
+ 0, 1, 0, 0,
+ 0, 1, 1, 0,
+ 0, 1, 1, 1};
+
+ std::vector output_data = {
+ 3.97f, 0.073f, 4.25f, 5.65f,
+ 8.69f, -0.13f, 4.25f, 5.65f};
+
+ bool use_float16 = false;
+ bool is_unidirectional = false;
+ bool use_past_state = false;
+ int past_sequence_length = 0;
+ int input_hidden_size = 0;
+ int max_sequence_length = 4;
+ bool only_enable_cuda = true; // only support 4D mask in cuda
+ const std::vector* past_data = nullptr;
+ const std::vector* present_data = nullptr;
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ use_float16, is_unidirectional, use_past_state, past_sequence_length,
+ past_data, present_data, kMask4D, input_hidden_size, max_sequence_length,
+ only_enable_cuda);
+}
+
TEST(AttentionTest, AttentionMaskIndexOutOfRange) {
int batch_size = 2;
int sequence_length = 2;