From 5eac2c1f41a0e522c8ce70f5563a317c239e4c1d Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Fri, 6 Jan 2023 17:32:58 -0800 Subject: [PATCH] relational attention bias cuda op (#14149) ### Description This cuda op implements the compute_bias() method in T5 Attention including the permutation. note: 1. bias_table needs to be saved in col-major. be careful when implementing fusion script 2. second input(sequence length) is placed on cpu. (using Shape node's output should be good) 3. the first dimension of output is 1, so extra_add_qk in attention should support broadcasting 4. compute_bias() only used in self-attn in t5 TODO: docs change will be applied later ### Motivation and Context It's part of the process of optimizing t5 attention as well as t5 based generation model Co-authored-by: Ubuntu --- docs/ContribOperators.md | 46 +++++ docs/OperatorKernels.md | 1 + .../cuda/bert/relative_attn_bias.cc | 72 ++++++++ .../cuda/bert/relative_attn_bias.h | 27 +++ .../cuda/bert/relative_attn_bias_impl.cu | 125 ++++++++++++++ .../cuda/bert/relative_attn_bias_impl.h | 26 +++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 + .../core/graph/contrib_ops/bert_defs.cc | 23 +++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../relative_attention_bias_test.cc | 159 ++++++++++++++++++ 10 files changed, 485 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h create mode 100644 onnxruntime/test/contrib_ops/relative_attention_bias_test.cc diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 507b1b259d..3fe9fb1d3a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -71,6 +71,7 @@ Do not modify directly.* * com.microsoft.QuickGelu * com.microsoft.Range * com.microsoft.ReduceSumInteger + * com.microsoft.RelativePositionBias * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft @@ -3704,6 +3705,51 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RelativePositionBias** + + Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2 + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
is_bidirectional : int
+
Default value is 0.
+
max_distance : int (required)
+
Max distance
+
+ +#### Inputs + +
+
bias_table : T
+
2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)
+
query_length : U
+
The length of query. Self Attention requires query_length = key_length
+
key_length : U
+
The length of key.
+
+ +#### Outputs + +
+
output : T
+
4D output tensor with shape (1, num_heads, sequence_length, sequence_length)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float or half tensors.
+
U : tensor(int64)
+
Constrain sequence_length to int tensors.
+
+ + ### **com.microsoft.RemovePadding** Compress transformer input by removing paddings. It assumes padding is on the right side of sequence. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d343d132aa..a2e91c136e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -796,6 +796,7 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)| |QuantizeWithOrder|*in* input:**F**
*in* scale_input:**S**
*out* output:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RelativePositionBias|*in* bias_table:**T**
*in* query_length:**U**
*in* key_length:**U**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc new file mode 100644 index 0000000000..88ebbfe831 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "relative_attn_bias.h" +#include "relative_attn_bias_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RelativePositionBias , \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + RelPosAttnBias); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +RelPosAttnBias::RelPosAttnBias(const OpKernelInfo& info) : CudaKernel(info) { + is_bidirectional_ = info.GetAttrOrDefault("is_bidirectional", 0) == 1; + + int64_t max_distance = 0; + ORT_ENFORCE(info.GetAttr("max_distance", &max_distance).IsOK() && max_distance > 0); + max_distance_ = static_cast(max_distance); +} + +template +Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const { + const Tensor* bias_table = context->Input(0); + const Tensor* query_length = context->Input(1); + const Tensor* key_length = context->Input(2); + + const auto& bias_table_dims = bias_table->Shape().GetDims(); + const int64_t num_buckets = bias_table_dims[0]; + const int64_t num_heads = bias_table_dims[1]; + + const int64_t query_len = *query_length->Data(); + const int64_t key_len = *key_length->Data(); + + if (query_len != key_len) { + ORT_THROW("Relatvie position bias currently only support query length equal to key length in Self Attention."); + } + + Tensor* output = context->Output(0, {1, num_heads, query_len, key_len}); + + typedef typename ToCudaType::MappedType CudaT; + + return LaunchRelPosAttnBiasKernel(Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(bias_table->template Data()), + static_cast(num_heads), + static_cast(query_len), + static_cast(num_buckets), + max_distance_, + is_bidirectional_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h new file mode 100644 index 0000000000..b9674f6f35 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RelPosAttnBias final : public CudaKernel { + public: + RelPosAttnBias(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int max_distance_; + bool is_bidirectional_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu new file mode 100644 index 0000000000..a30fe5b3db --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -0,0 +1,125 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/relative_attn_bias_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +__global__ void buildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance) { + const int head_id = blockIdx.x; + for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) { + int row_id = seq_id / seq_len; + int col_id = seq_id % seq_len; + + int relative_position = col_id - row_id; + + int relative_buckets = 0; + int tmp_num_bucket = num_bucket; + + if (is_bidirectional) { + tmp_num_bucket /= 2; + if (relative_position > 0) { + relative_buckets += tmp_num_bucket; + } else { + relative_position *= -1; + } + } else { + if (relative_position > 0) { + relative_position = 0; + } else { + relative_position *= -1; + } + } + + int max_exact = tmp_num_bucket / 2; + bool is_small = relative_position < max_exact; + + int relative_position_if_large = + max_exact + + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact) + * (tmp_num_bucket - max_exact)); + + relative_position_if_large = min(relative_position_if_large, tmp_num_bucket - 1); + + relative_buckets += is_small ? relative_position : relative_position_if_large; + + relative_attention_bias[head_id * seq_len * seq_len + seq_id] = + relative_attention_bias_table[head_id * num_bucket + relative_buckets]; + } +} + +template +Status LaunchRelPosAttnBiasKernel( + cudaStream_t stream, + T* output, + const T* bias_table, + const int num_heads, + const int seq_len, + const int num_bucket, + const int max_distance, + const bool is_bidirectional) +{ + dim3 grid(num_heads); + dim3 block(256); + + buildRelativeAttentionBias<<>>(output, + bias_table, + num_heads, + seq_len, + num_bucket, + is_bidirectional, + max_distance); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, + float* output, + const float* bias_table, + const int num_heads, + const int seq_len, + const int num_bucket, + const int max_distance, + const bool is_bidirectional); + +template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, + half* output, + const half* bias_table, + const int num_heads, + const int seq_len, + const int num_bucket, + const int max_distance, + const bool is_bidirectional); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h new file mode 100644 index 0000000000..a1efd2755f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRelPosAttnBiasKernel( + cudaStream_t stream, + T* output, + const T* bias_table, + const int num_heads, + const int seq_len, + const int num_bucket, + const int max_distance, + const bool is_bidirectional +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 97995dc178..c1460a6a56 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -28,6 +28,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding); @@ -151,6 +153,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 1fa9142aaf..34b1317f05 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -436,6 +436,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( schema.BuildFunction(functionProto); return true; })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + RelativePositionBias, 1, + OpSchema() + .SetDoc("Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2") + .Attr("max_distance", "Max distance", AttributeProto::INT) + .Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast(0)) + .Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T") + .Input(1, "query_length", "The length of query. Self Attention requires query_length = key_length", "U") + .Input(2, "key_length", "The length of key.", "U") + .Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") + .TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + auto& bias_table_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + output_shape.add_dim()->set_dim_value(1); + *output_shape.add_dim() = bias_table_shape.dim(1); + output_shape.add_dim(); + output_shape.add_dim(); + updateOutputShape(ctx, 0, output_shape); + })); ONNX_MS_OPERATOR_SET_SCHEMA( SkipLayerNormalization, 1, diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index df30bdc6a3..673cbbaf8e 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -78,6 +78,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); @@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc new file mode 100644 index 0000000000..7722291bee --- /dev/null +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunRelativePositionBiasTest( + const std::vector& bias_table, // Shape = [num_buckets, num_heads] + const std::vector& sequence_length, // Shape = [1] + const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] + int max_distance, + int num_buckets, + int num_heads, + int seq_len, + int is_bidirectional, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = false; + if (enable_cpu || enable_cuda) { + OpTester tester("RelativePositionBias", 1, onnxruntime::kMSDomain); + tester.AddAttribute("max_distance", static_cast(max_distance)); + tester.AddAttribute("is_bidirectional", static_cast(is_bidirectional)); + + std::vector bias_table_dims = {num_buckets, num_heads}; + std::vector sequence_length_dims = {1}; + std::vector output_dims = {1, num_heads, seq_len, seq_len}; + + if (use_float16) { + tester.AddInput("bias_table", bias_table_dims, ToFloat16(bias_table)); + tester.AddInput("query_length", sequence_length_dims, sequence_length); + tester.AddInput("key_length", sequence_length_dims, sequence_length); + tester.AddOutput("output", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("bias_table", bias_table_dims, bias_table); + tester.AddInput("query_length", sequence_length_dims, sequence_length); + tester.AddInput("key_length", sequence_length_dims, sequence_length); + tester.AddOutput("output", output_dims, output_data); + } + + if (enable_cuda) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + + if (enable_cpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP32) { + int max_distance = 128; + int num_buckets = 4; + int num_heads = 2; + int seq_len = 2; + int is_bidirectional = 1; + + // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]]. + // Save in col-major order in ORT + std::vector bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f}; + std::vector sequence_length = {seq_len}; + + std::vector output_data = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f}; + + RunRelativePositionBiasTest(bias_table, + sequence_length, + output_data, + max_distance, + num_buckets, + num_heads, + seq_len, + is_bidirectional); +} + +TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16) { + int max_distance = 128; + int num_buckets = 4; + int num_heads = 2; + int seq_len = 2; + int is_bidirectional = 1; + + // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]]. + // Save in col-major order in ORT + std::vector bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f}; + std::vector sequence_length = {seq_len}; + + std::vector output_data = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f}; + + RunRelativePositionBiasTest(bias_table, + sequence_length, + output_data, + max_distance, + num_buckets, + num_heads, + seq_len, + is_bidirectional, + true); +} + +TEST(RelativePositionBiasTest, RelativePositionBiasTest2_FP16) { + int max_distance = 128; + int num_buckets = 4; + int num_heads = 3; + int seq_len = 2; + int is_bidirectional = 1; + + // Huggingface bias_table = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]. + // Save in col-major order in ORT + std::vector bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f}; + std::vector sequence_length = {seq_len}; + + std::vector output_data = {1.f, 10.f, 4.f, 1.f, 2.f, 11.f, 5.f, 2.f, 3.f, 12.f, 6.f, 3.f}; + + RunRelativePositionBiasTest(bias_table, + sequence_length, + output_data, + max_distance, + num_buckets, + num_heads, + seq_len, + is_bidirectional, + true); +} + +TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) { + int max_distance = 128; + int num_buckets = 4; + int num_heads = 3; + int seq_len = 2; + int is_bidirectional = 0; + + std::vector bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f}; + std::vector sequence_length = {seq_len}; + + std::vector output_data = {1.f, 1.f, 4.f, 1.f, 2.f, 2.f, 5.f, 2.f, 3.f, 3.f, 6.f, 3.f}; + + RunRelativePositionBiasTest(bias_table, + sequence_length, + output_data, + max_distance, + num_buckets, + num_heads, + seq_len, + is_bidirectional, + true); +} + +} // namespace test +} // namespace onnxruntime