mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
revert 262e9ef21d (#5882)
Co-authored-by: wangye <wangye@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
bef06dac93
commit
ab9d4b366b
5 changed files with 36 additions and 112 deletions
@@ -48,7 +48,6 @@ AttentionBase::AttentionBase(const OpKernelInfo& info) {
   num_heads_ = static_cast<int>(num_heads);
 
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  is_input_dim_swapped_ = info.GetAttrOrDefault<int64_t>("input_dimension_swapped", 0) == 1;
 }
 
 Status AttentionBase::CheckInputs(const TensorShape& input_shape,

@@ -57,7 +56,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
                                   const Tensor*& mask_index,
                                   const Tensor* past) const {
   // Input shapes:
-  // input : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
+  // input : (batch_size, sequence_length, hidden_size)
   // weights : (hidden_size, 3 * hidden_size)
   // bias : (3 * hidden_size)
   // mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)

@@ -68,8 +67,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
                            dims.size());
   }
-  int batch_size = is_input_dim_swapped_ ? static_cast<int>(dims[1]) : static_cast<int>(dims[0]);
-  int sequence_length = is_input_dim_swapped_ ? static_cast<int>(dims[0]) : static_cast<int>(dims[1]);
+  int batch_size = static_cast<int>(dims[0]);
+  int sequence_length = static_cast<int>(dims[1]);
   int hidden_size = static_cast<int>(dims[2]);
   if (hidden_size % num_heads_ != 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
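The divisibility check kept in the hunk above is what ties hidden_size to the per-head width. A minimal standalone sketch with illustrative sizes (768 and 12 are assumptions for this example, not values taken from the diff):

#include <cassert>

int main() {
  // Illustrative sizes only: hidden_size must be a multiple of num_heads,
  // otherwise CheckInputs rejects the input with INVALID_ARGUMENT;
  // head_size is the per-head width the kernel works with.
  const int hidden_size = 768;
  const int num_heads = 12;
  assert(hidden_size % num_heads == 0);
  const int head_size = hidden_size / num_heads;  // 64
  return head_size == 64 ? 0 : 1;
}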
@@ -239,8 +238,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
                                   past));
 
   const auto& shape = input->Shape().GetDims();
-  const int batch_size = is_input_dim_swapped_ ? static_cast<int>(shape[1]) : static_cast<int>(shape[0]);
-  const int sequence_length = is_input_dim_swapped_ ? static_cast<int>(shape[0]) : static_cast<int>(shape[1]);
+  const int batch_size = static_cast<int>(shape[0]);
+  const int sequence_length = static_cast<int>(shape[1]);
   const int hidden_size = static_cast<int>(shape[2]);
   const int head_size = hidden_size / num_heads_;
 
@@ -28,7 +28,6 @@ class AttentionBase {
 
   int num_heads_; // number of attention heads
   bool is_unidirectional_; // whether every token can only attend to previous tokens.
-  bool is_input_dim_swapped_; // whether the input_shape is (S, B, NH) instead of (B, S, NH)
 };
 
 } // namespace contrib
@@ -42,11 +42,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past));
 
   // Input and output shapes:
-  // Input 0 - input : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
-  // Output 0 - output : (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size)
+  // Input 0 - input : (batch_size, sequence_length, hidden_size)
+  // Output 0 - output : (batch_size, sequence_length, hidden_size)
   const auto& shape = input->Shape();
-  int batch_size = is_input_dim_swapped_ ? static_cast<int>(shape[1]) : static_cast<int>(shape[0]);
-  int sequence_length = is_input_dim_swapped_ ? static_cast<int>(shape[0]) : static_cast<int>(shape[1]);
+  int batch_size = static_cast<int>(shape[0]);
+  int sequence_length = static_cast<int>(shape[1]);
   int hidden_size = static_cast<int>(shape[2]);
   int head_size = hidden_size / num_heads_;
 
@@ -282,8 +282,7 @@ we also support other two formats: When input has right-side padding, mask_index
 where value of each element is the end position, or valid length of actual sequence excluding padding. When input has
 left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
 the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
-and present state are optional. Present state could appear in output even when past state is not in input. When
-input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, hidden_size) which happens in Bart models, etc.
+and present state are optional. Present state could appear in output even when past state is not in input.
 )DOC";
 
 ONNX_CONTRIB_OPERATOR_SCHEMA(Attention)
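For readers of the operator doc above, a small sketch of the two 1-D mask_index layouts it describes; the concrete values are invented for illustration and are not taken from this commit:

#include <cstdint>
#include <vector>

int main() {
  // Assume batch_size = 2 and sequence_length = 4 (illustrative only).
  // Right-side padding: one value per batch item, the valid length
  // (equivalently, the exclusive end position of the real tokens).
  std::vector<int32_t> mask_right_padding = {3, 2};      // shape (batch_size)

  // Left-side padding: exclusive end positions for all batch items,
  // followed by the inclusive start positions. Here item 0 holds real
  // tokens at positions 1..3 and item 1 at positions 2..3.
  std::vector<int32_t> mask_left_padding = {4, 4, 1, 2};  // shape (2 * batch_size)

  return (mask_right_padding.size() == 2 && mask_left_padding.size() == 4) ? 0 : 1;
}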
@@ -295,16 +294,12 @@ input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, h
             "Whether every token can only attend to previous tokens. Default value is 0.",
             AttributeProto::INT,
             static_cast<int64_t>(0))
-      .Attr("input_dimension_swapped",
-            "Whether input shape is (sequence_length, batch_size, hidden_size) instead of (batch_size, sequence_length, hidden_size). Default value is 0.",
-            AttributeProto::INT,
-            static_cast<int64_t>(0))
-      .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size) or (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T")
+      .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
       .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
       .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
       .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
       .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
-      .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size) or (sequence_length, batch_size, hidden_size)", "T")
+      .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
       .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
       .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
       .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
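As a sanity check on the shape contract the schema lines above spell out, a sketch with made-up sizes; only the relationships between the shapes come from the schema:

#include <cstdint>
#include <vector>

int main() {
  // Illustrative sizes. Q, K and V projections share one packed weight and
  // bias, which is where the 3 * hidden_size factor comes from.
  const int64_t batch_size = 2, sequence_length = 8;
  const int64_t num_heads = 4, head_size = 16;
  const int64_t hidden_size = num_heads * head_size;  // 64

  const std::vector<int64_t> input_shape  = {batch_size, sequence_length, hidden_size};
  const std::vector<int64_t> weight_shape = {hidden_size, 3 * hidden_size};
  const std::vector<int64_t> bias_shape   = {3 * hidden_size};
  const std::vector<int64_t> output_shape = input_shape;  // output mirrors input

  return (weight_shape[1] == 192 && bias_shape[0] == 192) ? 0 : 1;
}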
@@ -332,9 +327,6 @@ input_dimension_swapped is 1, the input shape is (sequence_length, batch_size, h
           }
 
           if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
-            if (ctx.getAttribute("input_dimension_swapped")->i() != 0) {
-              fail_shape_inference("Past shall be work with input_dimension_swapped=0. aka when input shape equals to (B,S,NH)");
-            }
             auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
 
             ONNX_NAMESPACE::TensorShapeProto present_shape;
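The shape-inference code retained above derives the present-state length by adding the past and current sequence lengths; a rough sketch with assumed numbers, not code from the repository:

#include <cstdint>
#include <vector>

int main() {
  // Illustrative values. past has shape (2, batch, num_heads, past_seq, head_size)
  // and input has shape (batch, seq, hidden); present covers past plus current
  // tokens, mirroring the all_sequence_length computation kept by the hunk above.
  const int64_t batch_size = 1, num_heads = 2, head_size = 4;
  const int64_t past_sequence_length = 3, sequence_length = 2;

  const int64_t all_sequence_length = past_sequence_length + sequence_length;  // 5
  const std::vector<int64_t> present_shape = {2, batch_size, num_heads,
                                              all_sequence_length, head_size};
  return present_shape[3] == 5 ? 0 : 1;
}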
@@ -17,19 +17,18 @@ enum MaskIndexType {
 };
 
 static void RunAttentionTest(
-    const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
     bool is_weights_constant,
     const std::vector<float>& bias_data, // bias: [3 * hidden_size]
     const std::vector<int32_t>& mask_index_data, // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
-    const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size]
     int batch_size,
     int sequence_length,
     int hidden_size,
     int number_of_heads,
     bool use_float16 = false,
     bool is_unidirectional = false,
-    bool is_input_dimension_swapped = false,
     bool use_past_state = false,
     int past_sequence_length = 0,
     const std::vector<float>* past_data = nullptr,

@@ -44,7 +43,6 @@ static void RunAttentionTest(
   OpTester tester("Attention", 1, onnxruntime::kMSDomain);
   tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
   tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
-  tester.AddAttribute<int64_t>("input_dimension_swapped", static_cast<int64_t>(is_input_dimension_swapped ? 1 : 0));
 
   std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
   std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};

@@ -124,18 +122,17 @@ static void RunAttentionTest(
 }
 
 static void RunAttentionTest(
-    const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
     const std::vector<float>& bias_data, // bias: [3 * hidden_size]
     const std::vector<int32_t>& mask_index_data, // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
-    const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size] or [sequence_length, batch_size, hidden_size]
+    const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size]
     int batch_size,
     int sequence_length,
     int hidden_size,
     int number_of_heads,
     bool use_float16 = false,
     bool is_unidirectional = false,
-    bool is_input_dimension_swapped = false,
     bool use_past_state = false,
     int past_sequence_length = 0,
     const std::vector<float>* past_data = nullptr,

@@ -143,12 +140,12 @@ static void RunAttentionTest(
     MaskIndexType mask_index_type = kMaskIndexEnd) {
   RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, mask_index_type);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length,
+                   past_data, present_data, mask_index_type);
   RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, mask_index_type);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length,
+                   past_data, present_data, mask_index_type);
 }
 
 TEST(AttentionTest, AttentionBatch1) {
@@ -332,41 +329,6 @@ TEST(AttentionTest, AttentionNoMaskIndex) {
                    batch_size, sequence_length, hidden_size, number_of_heads);
 }
 
-TEST(AttentionTest, AttentionNoMaskInputShapeSwapped) {
-  int batch_size_as_sequence_length = 1;
-  int sequence_length_as_batch_size = 2;
-  int hidden_size = 4;
-  int number_of_heads = 2;
-
-  std::vector<float> input_data = {
-      0.8f, -0.5f, 0.0f, 1.f,
-      0.5f, 0.2f, 0.3f, -0.6f};
-
-  std::vector<float> weight_data = {
-      0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
-      0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
-      0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
-      0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
-
-  std::vector<float> bias_data = {
-      -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
-
-  // No mask_index
-  std::vector<int32_t> mask_index_data = {};
-
-  std::vector<float> output_data = {
-      8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f,
-      -4.0900001525878906f, 0.42000001668930054f, -0.10999995470046997f, 0.56999993324279785f};
-
-  bool use_float16 = false;
-  bool is_unidirectional = false;
-  bool is_input_dimension_swapped = true;
-
-  RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
-                   batch_size_as_sequence_length, sequence_length_as_batch_size, hidden_size,
-                   number_of_heads, use_float16, is_unidirectional, is_input_dimension_swapped);
-}
-
 TEST(AttentionTest, AttentionUnidirectional) {
   int batch_size = 1;
   int sequence_length = 2;
@@ -544,13 +506,12 @@ TEST(AttentionTest, AttentionEmptyPastState) {
       0.053175069391727448f, 0.12795503437519073f, 0.11125634610652924f, -0.0510881207883358f, -0.55345797538757324f, -0.3045809268951416f, -0.36920222640037537f, 0.060108467936515808f, 0.28109729290008545f, 0.069518551230430603f, 0.45718482136726379f, -0.010400654748082161f, 0.0038009658455848694f, 0.29213353991508484f, -0.17697516083717346f, 0.27086889743804932f};
 
   bool is_unidirectional = true;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = true;
   int past_sequence_length = 0;
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional,
-                   is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data);
+                   use_past_state, past_sequence_length, &past_data, &present_data);
 }
 
 TEST(AttentionTest, AttentionPastStateBatch1) {

@@ -644,13 +605,12 @@ TEST(AttentionTest, AttentionPastStateBatch1) {
       0.45075402f, 0.85365993f, 0.097346395f, 0.28859729f, 0.26926181f, 0.65922296f, -0.027254611f, -0.096526355f, 0.8177433f, 0.4212271f, 0.34352475f, 0.059609573f, 0.46556228f, 0.7226882f, -0.025281552f, -0.25482416f};
 
   bool is_unidirectional = true;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = true;
   int past_sequence_length = 3;
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional,
-                   is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data);
+                   use_past_state, past_sequence_length, &past_data, &present_data);
 }
 
 TEST(AttentionTest, AttentionPastStateBatch2) {

@@ -748,13 +708,12 @@ TEST(AttentionTest, AttentionPastStateBatch2) {
       0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, -0.033968594f, -0.034833729f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, 0.045759238f, -0.26863033f, 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, -0.033201858f, -0.10592631f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f, -0.054369561f, -0.24562015f};
 
   bool is_unidirectional = true;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = true;
   int past_sequence_length = 3;
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional,
-                   is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data);
+                   use_past_state, past_sequence_length, &past_data, &present_data);
 }
 
 TEST(AttentionTest, AttentionPastStateBatch2WithPadding) {
@@ -852,13 +811,11 @@ TEST(AttentionTest, AttentionPastStateBatch2WithPadding) {
       0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, -0.033968594f, -0.034833729f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, 0.045759238f, -0.26863033f, 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, -0.033201858f, -0.10592631f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f, -0.054369561f, -0.24562015f};
 
   bool is_unidirectional = true;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = true;
   int past_sequence_length = 3;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional,
-                   is_input_dimension_swapped, use_past_state, past_sequence_length, &past_data, &present_data,
-                   kMaskIndexEndAndStart);
+                   use_past_state, past_sequence_length, &past_data, &present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionBatch2MaskIndex2) {

@@ -892,15 +849,13 @@ TEST(AttentionTest, AttentionBatch2MaskIndex2) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionRightPaddingMaskIndex2) {
@@ -931,15 +886,13 @@ TEST(AttentionTest, AttentionRightPaddingMaskIndex2) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) {

@@ -970,15 +923,13 @@ TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) {
@@ -1013,15 +964,13 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionBatch2AttentionMask) {

@@ -1056,15 +1005,13 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskRaw);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
 }
 
 TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
@@ -1099,15 +1046,13 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
 
   bool use_float16 = false;
   bool is_unidirectional = true;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskRaw);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
 }
 
 TEST(AttentionTest, AttentionMask1DEndNoWord) {

@@ -1142,7 +1087,6 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;

@@ -1150,8 +1094,7 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) {
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEnd);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEnd);
 }
 
 TEST(AttentionTest, AttentionMask1DNoWord) {
@@ -1186,7 +1129,6 @@ TEST(AttentionTest, AttentionMask1DNoWord) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;

@@ -1194,8 +1136,7 @@ TEST(AttentionTest, AttentionMask1DNoWord) {
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionMask2DNoWord) {

@@ -1230,7 +1171,6 @@ TEST(AttentionTest, AttentionMask2DNoWord) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;

@@ -1238,8 +1178,7 @@ TEST(AttentionTest, AttentionMask2DNoWord) {
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskRaw);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
 }
 
 
@@ -1274,7 +1213,6 @@ TEST(AttentionTest, AttentionDummyMask2D) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;

@@ -1282,8 +1220,7 @@ TEST(AttentionTest, AttentionDummyMask2D) {
 
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskDummy);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskDummy);
 }
 
 TEST(AttentionTest, AttentionMaskIndexOutOfRange) {

@@ -1318,15 +1255,13 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) {
 
   bool use_float16 = false;
   bool is_unidirectional = false;
-  bool is_input_dimension_swapped = false;
   bool use_past_state = false;
   int past_sequence_length = 0;
   const std::vector<float>* past_data = nullptr;
   const std::vector<float>* present_data = nullptr;
   RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                    batch_size, sequence_length, hidden_size, number_of_heads,
-                   use_float16, is_unidirectional, is_input_dimension_swapped, use_past_state,
-                   past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
+                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
 }
 
 TEST(AttentionTest, AttentionPastState_dynamic) {
@@ -1348,7 +1283,6 @@ TEST(AttentionTest, AttentionPastState_dynamic) {
   OpTester test("Attention", 1, onnxruntime::kMSDomain);
   test.AddAttribute<int64_t>("num_heads", 12);
   test.AddAttribute<int64_t>("unidirectional", 1);
-  test.AddAttribute<int64_t>("input_dimension_swapped", 0);
   test.AddInput<float>("input", input_dims, input_data);
   test.AddInput<float>("weight", weight_dims, weight_data);
   test.AddInput<float>("bias", bias_dims, bias_data);

@@ -1360,4 +1294,4 @@ TEST(AttentionTest, AttentionPastState_dynamic) {
 }
 
 } // namespace test
-} // namespace onnxruntime
+} // namespace onnxruntime