From ab9d4b366b7d6add87dd5ed18db299a5145884d3 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Fri, 20 Nov 2020 13:05:53 -0800
Subject: [PATCH] revert 262e9ef21dc319a8e7bece1f0d155364b17b019c (#5882)

Co-authored-by: wangye
---
 onnxruntime/contrib_ops/cpu/bert/attention.cc |  11 +-
 .../contrib_ops/cpu/bert/attention_base.h     |   1 -
 .../contrib_ops/cuda/bert/attention.cc        |   8 +-
 .../core/graph/contrib_ops/contrib_defs.cc    |  14 +--
 .../test/contrib_ops/attention_op_test.cc     | 114 ++++--------------
 5 files changed, 36 insertions(+), 112 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index d41524177b..01d17144ff 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -48,7 +48,6 @@ AttentionBase::AttentionBase(const OpKernelInfo& info) {
   num_heads_ = static_cast<int>(num_heads);
 
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  is_input_dim_swapped_ = info.GetAttrOrDefault<int64_t>("input_dimension_swapped", 0) == 1;
 }
 
 Status AttentionBase::CheckInputs(const TensorShape& input_shape,
@@ -57,7 +56,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
                                   const Tensor*& mask_index,
                                   const Tensor* past) const {
   // Input shapes:
-  //   input       : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
+  //   input       : (batch_size, sequence_length, hidden_size)
   //   weights     : (hidden_size, 3 * hidden_size)
   //   bias        : (3 * hidden_size)
   //   mask_index  : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
@@ -68,8 +67,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
                            dims.size());
   }
-  int batch_size = is_input_dim_swapped_ ? static_cast<int>(dims[1]) : static_cast<int>(dims[0]);
-  int sequence_length = is_input_dim_swapped_ ? static_cast<int>(dims[0]) : static_cast<int>(dims[1]);
+  int batch_size = static_cast<int>(dims[0]);
+  int sequence_length = static_cast<int>(dims[1]);
   int hidden_size = static_cast<int>(dims[2]);
   if (hidden_size % num_heads_ != 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -239,8 +238,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
                                   past));
 
   const auto& shape = input->Shape().GetDims();
-  const int batch_size = is_input_dim_swapped_ ? static_cast<int>(shape[1]) : static_cast<int>(shape[0]);
-  const int sequence_length = is_input_dim_swapped_ ? static_cast<int>(shape[0]) : static_cast<int>(shape[1]);
+  const int batch_size = static_cast<int>(shape[0]);
+  const int sequence_length = static_cast<int>(shape[1]);
   const int hidden_size = static_cast<int>(shape[2]);
   const int head_size = hidden_size / num_heads_;
 
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 9190dddebc..b61fe2597b 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -28,7 +28,6 @@ class AttentionBase {
 
   int num_heads_;           // number of attention heads
   bool is_unidirectional_;  // whether every token can only attend to previous tokens.
-  bool is_input_dim_swapped_;  // whether the input_shape is (S, B, NH) instead of (B, S, NH)
 };
 
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 879f89e9d7..25a23a5111 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -42,11 +42,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past));
 
   // Input and output shapes:
-  //   Input 0 - input   : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
-  //   Output 0 - output : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
+  //   Input 0 - input   : (batch_size, sequence_length, hidden_size)
+  //   Output 0 - output : (batch_size, sequence_length, hidden_size)
   const auto& shape = input->Shape();
-  int batch_size = is_input_dim_swapped_ ? static_cast<int>(shape[1]) : static_cast<int>(shape[0]);
-  int sequence_length = is_input_dim_swapped_ ? static_cast<int>(shape[0]) : static_cast<int>(shape[1]);
+  int batch_size = static_cast<int>(shape[0]);
+  int sequence_length = static_cast<int>(shape[1]);
   int hidden_size = static_cast<int>(shape[2]);
   int head_size = hidden_size / num_heads_;
 
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 60383fb862..084e145f5f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -282,8 +281,7 @@ we also support other two formats:
 When input has right-side padding, mask_index where value of each element is the end position, or valid length of actual sequence excluding padding.
 When input has left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
 the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
-and present state are optional. Present state could appear in output even when past state is not in input. When
-input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, hidden_size) which happens in Bart models, etc.
+and present state are optional. Present state could appear in output even when past state is not in input.
 )DOC";
 
 ONNX_CONTRIB_OPERATOR_SCHEMA(Attention)
@@ -295,16 +294,12 @@ input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, h
             "Whether every token can only attend to previous tokens. Default value is 0.",
             AttributeProto::INT,
             static_cast<int64_t>(0))
-      .Attr("input_dimension_swapped",
-            "Whether input shape is (sequence_length, batch_size, hidden_size) instead of (batch_size, sequence_length, hidden_size). 
Default value is 0.",
-            AttributeProto::INT,
-            static_cast<int64_t>(0))
-      .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T")
+      .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
       .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
       .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
       .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
       .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
-      .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size) or (sequence_length, batch_size, hidden_size)", "T")
+      .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
       .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
       .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
       .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
@@ -332,9 +327,6 @@ input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, h
         }
 
         if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
-          if (ctx.getAttribute("input_dimension_swapped")->i() != 0) {
-            fail_shape_inference("Past shall be work with input_dimension_swapped=0. 
aka when input shape equals to (B,S,NH)");
-          }
           auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
 
           ONNX_NAMESPACE::TensorShapeProto present_shape;
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index 9f0ae570f8..2728a8ad5f 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -17,19 +17,18 @@ enum MaskIndexType {
 };
 
 static void RunAttentionTest(
-    const std::vector<float>& input_data,         // input: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& input_data,         // input: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& weights_data,       // weights: [hidden_size, 3 * hidden_size]
     bool is_weights_constant,
     const std::vector<float>& bias_data,          // bias: [3 * hidden_size]
     const std::vector<int32_t>& mask_index_data,  // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
-    const std::vector<float>& output_data,        // output: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& output_data,        // output: [batch_size, sequence_length, hidden_size]
     int batch_size,
     int sequence_length,
     int hidden_size,
     int number_of_heads,
     bool use_float16 = false,
     bool is_unidirectional = false,
-    bool is_input_dimension_swapped = false,
     bool use_past_state = false,
     int past_sequence_length = 0,
     const std::vector<float>* past_data = nullptr,
@@ -44,7 +43,6 @@ static void RunAttentionTest(
   OpTester tester("Attention", 1, onnxruntime::kMSDomain);
   tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
   tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
-  tester.AddAttribute<int64_t>("input_dimension_swapped", static_cast<int64_t>(is_input_dimension_swapped ? 
1 : 0));
 
   std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
   std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
@@ -124,18 +122,17 @@ static void RunAttentionTest(
 }
 
 static void RunAttentionTest(
-    const std::vector<float>& input_data,         // input: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& input_data,         // input: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& weights_data,       // weights: [hidden_size, 3 * hidden_size]
     const std::vector<float>& bias_data,          // bias: [3 * hidden_size]
     const std::vector<int32_t>& mask_index_data,  // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
-    const std::vector<float>& output_data,        // output: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& output_data,        // output: [batch_size, sequence_length, hidden_size]
     int batch_size,
     int sequence_length,
     int hidden_size,
     int number_of_heads,
     bool use_float16 = false,
     bool is_unidirectional = false,
-    bool is_input_dimension_swapped = false,
     bool use_past_state = false,
     int past_sequence_length = 0,
     const std::vector<float>* past_data = nullptr,
@@ -143,12 +140,12 @@ static void RunAttentionTest(
     MaskIndexType mask_index_type = kMaskIndexEnd) {
   RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, mask_index_type);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length,
+                   past_data, present_data, mask_index_type);
   RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, mask_index_type);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length,
+                   past_data, present_data, mask_index_type);
 }
 
 TEST(AttentionTest, AttentionBatch1) {
@@ -332,41 +329,6 @@ TEST(AttentionTest, AttentionNoMaskIndex) {
                    batch_size, sequence_length, hidden_size, number_of_heads);
 }
 
-TEST(AttentionTest, AttentionNoMaskInputShapeSwapped) {
-  int batch_size_as_sequence_length = 1;
-  int sequence_length_as_batch_size = 2;
-  int hidden_size = 4;
-  int number_of_heads = 2;
-
-  std::vector<float> input_data = {
-      0.8f, -0.5f, 0.0f, 1.f,
-      0.5f, 0.2f, 0.3f, -0.6f};
-
-  std::vector<float> weight_data = {
-      0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
-      0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
-      0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
-      0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
-
-  std::vector<float> bias_data = {
-      -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
-
-  // No mask_index
-  std::vector<int32_t> mask_index_data = {};
-
-  std::vector<float> output_data = {
-      8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f,
-      -4.0900001525878906f, 0.42000001668930054f, -0.10999995470046997f, 0.56999993324279785f};
-
-  bool use_float16 = false;
-  bool is_unidirectional = false;
-  bool is_input_dimension_swapped = true;
-
-  RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
-                   batch_size_as_sequence_length, sequence_length_as_batch_size, hidden_size,
-                   number_of_heads, 
use_float16, is_unidirectional, is_input_dimension_swapped); -} - TEST(AttentionTest, AttentionUnidirectional) { int batch_size = 1; int sequence_length = 2; @@ -544,13 +506,12 @@ TEST(AttentionTest, AttentionEmptyPastState) { 0.053175069391727448f, 0.12795503437519073f, 0.11125634610652924f, -0.0510881207883358f, -0.55345797538757324f, -0.3045809268951416f, -0.36920222640037537f, 0.060108467936515808f, 0.28109729290008545f, 0.069518551230430603f, 0.45718482136726379f, -0.010400654748082161f, 0.0038009658455848694f, 0.29213353991508484f, -0.17697516083717346f, 0.27086889743804932f}; bool is_unidirectional = true; - bool is_input_dimension_swapped = false; bool use_past_state = true; int past_sequence_length = 0; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); } TEST(AttentionTest, AttentionPastStateBatch1) { @@ -644,13 +605,12 @@ TEST(AttentionTest, AttentionPastStateBatch1) { 0.45075402f, 0.85365993f, 0.097346395f, 0.28859729f, 0.26926181f, 0.65922296f, -0.027254611f, -0.096526355f, 0.8177433f, 0.4212271f, 0.34352475f, 0.059609573f, 0.46556228f, 0.7226882f, -0.025281552f, -0.25482416f}; bool is_unidirectional = true; - bool is_input_dimension_swapped = false; bool use_past_state = true; int past_sequence_length = 3; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); } TEST(AttentionTest, AttentionPastStateBatch2) { @@ -748,13 +708,12 @@ TEST(AttentionTest, AttentionPastStateBatch2) { 0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, -0.033968594f, -0.034833729f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, 0.045759238f, -0.26863033f, 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, -0.033201858f, -0.10592631f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f, -0.054369561f, -0.24562015f}; bool is_unidirectional = true; - bool is_input_dimension_swapped = false; bool use_past_state = true; int past_sequence_length = 3; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); } TEST(AttentionTest, AttentionPastStateBatch2WithPadding) { @@ -852,13 +811,11 @@ TEST(AttentionTest, AttentionPastStateBatch2WithPadding) { 0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, -0.033968594f, -0.034833729f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, 0.045759238f, -0.26863033f, 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, -0.033201858f, -0.10592631f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f, -0.054369561f, -0.24562015f}; bool is_unidirectional = true; - bool is_input_dimension_swapped = false; bool use_past_state = true; int 
past_sequence_length = 3; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data, - kMaskIndexEndAndStart); + use_past_state, past_sequence_length, &past_data, &present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionBatch2MaskIndex2) { @@ -892,15 +849,13 @@ TEST(AttentionTest, AttentionBatch2MaskIndex2) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { @@ -931,15 +886,13 @@ TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { @@ -970,15 +923,13 @@ TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { @@ -1013,15 +964,13 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, 
kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionBatch2AttentionMask) { @@ -1056,15 +1005,13 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskRaw); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); } TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { @@ -1099,15 +1046,13 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { bool use_float16 = false; bool is_unidirectional = true; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskRaw); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); } TEST(AttentionTest, AttentionMask1DEndNoWord) { @@ -1142,7 +1087,6 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; @@ -1150,8 +1094,7 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEnd); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEnd); } TEST(AttentionTest, AttentionMask1DNoWord) { @@ -1186,7 +1129,6 @@ TEST(AttentionTest, AttentionMask1DNoWord) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; @@ -1194,8 +1136,7 @@ TEST(AttentionTest, AttentionMask1DNoWord) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionMask2DNoWord) { @@ -1230,7 +1171,6 @@ TEST(AttentionTest, AttentionMask2DNoWord) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; @@ -1238,8 +1178,7 @@ TEST(AttentionTest, AttentionMask2DNoWord) { RunAttentionTest(input_data, 
weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskRaw); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); } @@ -1274,7 +1213,6 @@ TEST(AttentionTest, AttentionDummyMask2D) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; @@ -1282,8 +1220,7 @@ TEST(AttentionTest, AttentionDummyMask2D) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskDummy); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskDummy); } TEST(AttentionTest, AttentionMaskIndexOutOfRange) { @@ -1318,15 +1255,13 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) { bool use_float16 = false; bool is_unidirectional = false; - bool is_input_dimension_swapped = false; bool use_past_state = false; int past_sequence_length = 0; const std::vector* past_data = nullptr; const std::vector* present_data = nullptr; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state, - past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionPastState_dynamic) { @@ -1348,7 +1283,6 @@ TEST(AttentionTest, AttentionPastState_dynamic) { OpTester test("Attention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", 12); test.AddAttribute("unidirectional", 1); - test.AddAttribute("input_dimension_swapped", 0); test.AddInput("input", input_dims, input_data); test.AddInput("weight", weight_dims, weight_data); test.AddInput("bias", bias_dims, bias_data); @@ -1360,4 +1294,4 @@ TEST(AttentionTest, AttentionPastState_dynamic) { } } // namespace test -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file