diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 507b1b259d..3fe9fb1d3a 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -71,6 +71,7 @@ Do not modify directly.*
* com.microsoft.QuickGelu
* com.microsoft.Range
* com.microsoft.ReduceSumInteger
+ * com.microsoft.RelativePositionBias
* com.microsoft.RemovePadding
* com.microsoft.RestorePadding
* com.microsoft.Rfft
@@ -3704,6 +3705,51 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.RelativePositionBias**
+
+ Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- is_bidirectional : int
+- Default value is 0.
+- max_distance : int (required)
+- Max distance
+
+
+#### Inputs
+
+
+- bias_table : T
+- 2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)
+- query_length : U
+- The length of query. Self Attention requires query_length = key_length
+- key_length : U
+- The length of key.
+
+
+#### Outputs
+
+
+- output : T
+- 4D output tensor with shape (1, num_heads, sequence_length, sequence_length)
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float or half tensors.
+- U : tensor(int64)
+- Constrain sequence_length to int tensors.
+
+
+
### **com.microsoft.RemovePadding**
Compress transformer input by removing paddings. It assumes padding is on the right side of sequence.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d343d132aa..a2e91c136e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -796,6 +796,7 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)|
|QuantizeWithOrder|*in* input:**F**
*in* scale_input:**S**
*out* output:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|RelativePositionBias|*in* bias_table:**T**
*in* query_length:**U**
*in* key_length:**U**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
new file mode 100644
index 0000000000..88ebbfe831
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "relative_attn_bias.h"
+#include "relative_attn_bias_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RelativePositionBias , \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .InputMemoryType(OrtMemTypeCPUInput, 1) \
+ .InputMemoryType(OrtMemTypeCPUInput, 2) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ RelPosAttnBias);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+using namespace ONNX_NAMESPACE;
+
+template
+RelPosAttnBias::RelPosAttnBias(const OpKernelInfo& info) : CudaKernel(info) {
+ is_bidirectional_ = info.GetAttrOrDefault("is_bidirectional", 0) == 1;
+
+ int64_t max_distance = 0;
+ ORT_ENFORCE(info.GetAttr("max_distance", &max_distance).IsOK() && max_distance > 0);
+ max_distance_ = static_cast(max_distance);
+}
+
+template
+Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* bias_table = context->Input(0);
+ const Tensor* query_length = context->Input(1);
+ const Tensor* key_length = context->Input(2);
+
+ const auto& bias_table_dims = bias_table->Shape().GetDims();
+ const int64_t num_buckets = bias_table_dims[0];
+ const int64_t num_heads = bias_table_dims[1];
+
+ const int64_t query_len = *query_length->Data();
+ const int64_t key_len = *key_length->Data();
+
+ if (query_len != key_len) {
+ ORT_THROW("Relatvie position bias currently only support query length equal to key length in Self Attention.");
+ }
+
+ Tensor* output = context->Output(0, {1, num_heads, query_len, key_len});
+
+ typedef typename ToCudaType::MappedType CudaT;
+
+ return LaunchRelPosAttnBiasKernel(Stream(context),
+ reinterpret_cast(output->template MutableData()),
+ reinterpret_cast(bias_table->template Data()),
+ static_cast(num_heads),
+ static_cast(query_len),
+ static_cast(num_buckets),
+ max_distance_,
+ is_bidirectional_);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
new file mode 100644
index 0000000000..b9674f6f35
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template
+class RelPosAttnBias final : public CudaKernel {
+ public:
+ RelPosAttnBias(const OpKernelInfo& op_kernel_info);
+ Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+ int max_distance_;
+ bool is_bidirectional_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
new file mode 100644
index 0000000000..a30fe5b3db
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
@@ -0,0 +1,125 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/bert/relative_attn_bias_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template
+__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
+ const T* relative_attention_bias_table,
+ const int head_num,
+ const int seq_len,
+ const int num_bucket,
+ const bool is_bidirectional,
+ const int max_distance) {
+ const int head_id = blockIdx.x;
+ for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) {
+ int row_id = seq_id / seq_len;
+ int col_id = seq_id % seq_len;
+
+ int relative_position = col_id - row_id;
+
+ int relative_buckets = 0;
+ int tmp_num_bucket = num_bucket;
+
+ if (is_bidirectional) {
+ tmp_num_bucket /= 2;
+ if (relative_position > 0) {
+ relative_buckets += tmp_num_bucket;
+ } else {
+ relative_position *= -1;
+ }
+ } else {
+ if (relative_position > 0) {
+ relative_position = 0;
+ } else {
+ relative_position *= -1;
+ }
+ }
+
+ int max_exact = tmp_num_bucket / 2;
+ bool is_small = relative_position < max_exact;
+
+ int relative_position_if_large =
+ max_exact
+ + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact)
+ * (tmp_num_bucket - max_exact));
+
+ relative_position_if_large = min(relative_position_if_large, tmp_num_bucket - 1);
+
+ relative_buckets += is_small ? relative_position : relative_position_if_large;
+
+ relative_attention_bias[head_id * seq_len * seq_len + seq_id] =
+ relative_attention_bias_table[head_id * num_bucket + relative_buckets];
+ }
+}
+
+template
+Status LaunchRelPosAttnBiasKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* bias_table,
+ const int num_heads,
+ const int seq_len,
+ const int num_bucket,
+ const int max_distance,
+ const bool is_bidirectional)
+{
+ dim3 grid(num_heads);
+ dim3 block(256);
+
+ buildRelativeAttentionBias<<>>(output,
+ bias_table,
+ num_heads,
+ seq_len,
+ num_bucket,
+ is_bidirectional,
+ max_distance);
+
+ return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream,
+ float* output,
+ const float* bias_table,
+ const int num_heads,
+ const int seq_len,
+ const int num_bucket,
+ const int max_distance,
+ const bool is_bidirectional);
+
+template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream,
+ half* output,
+ const half* bias_table,
+ const int num_heads,
+ const int seq_len,
+ const int num_bucket,
+ const int max_distance,
+ const bool is_bidirectional);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
new file mode 100644
index 0000000000..a1efd2755f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template
+Status LaunchRelPosAttnBiasKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* bias_table,
+ const int num_heads,
+ const int seq_len,
+ const int num_bucket,
+ const int max_distance,
+ const bool is_bidirectional
+);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 97995dc178..c1460a6a56 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -28,6 +28,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
@@ -151,6 +153,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 1fa9142aaf..34b1317f05 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -436,6 +436,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
schema.BuildFunction(functionProto);
return true;
}));
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ RelativePositionBias, 1,
+ OpSchema()
+ .SetDoc("Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2")
+ .Attr("max_distance", "Max distance", AttributeProto::INT)
+ .Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast(0))
+ .Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T")
+ .Input(1, "query_length", "The length of query. Self Attention requires query_length = key_length", "U")
+ .Input(2, "key_length", "The length of key.", "U")
+ .Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
+ .TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ auto& bias_table_shape = getInputShape(ctx, 0);
+ TensorShapeProto output_shape;
+ output_shape.add_dim()->set_dim_value(1);
+ *output_shape.add_dim() = bias_table_shape.dim(1);
+ output_shape.add_dim();
+ output_shape.add_dim();
+ updateOutputShape(ctx, 0, output_shape);
+ }));
ONNX_MS_OPERATOR_SET_SCHEMA(
SkipLayerNormalization, 1,
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index df30bdc6a3..673cbbaf8e 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -78,6 +78,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
@@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
new file mode 100644
index 0000000000..7722291bee
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
@@ -0,0 +1,159 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+static void RunRelativePositionBiasTest(
+ const std::vector& bias_table, // Shape = [num_buckets, num_heads]
+ const std::vector& sequence_length, // Shape = [1]
+ const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length]
+ int max_distance,
+ int num_buckets,
+ int num_heads,
+ int seq_len,
+ int is_bidirectional,
+ bool use_float16 = false) {
+ int min_cuda_architecture = use_float16 ? 530 : 0;
+
+ bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+ bool enable_cpu = false;
+ if (enable_cpu || enable_cuda) {
+ OpTester tester("RelativePositionBias", 1, onnxruntime::kMSDomain);
+ tester.AddAttribute("max_distance", static_cast(max_distance));
+ tester.AddAttribute("is_bidirectional", static_cast(is_bidirectional));
+
+ std::vector bias_table_dims = {num_buckets, num_heads};
+ std::vector sequence_length_dims = {1};
+ std::vector output_dims = {1, num_heads, seq_len, seq_len};
+
+ if (use_float16) {
+ tester.AddInput("bias_table", bias_table_dims, ToFloat16(bias_table));
+ tester.AddInput("query_length", sequence_length_dims, sequence_length);
+ tester.AddInput("key_length", sequence_length_dims, sequence_length);
+ tester.AddOutput("output", output_dims, ToFloat16(output_data));
+ } else {
+ tester.AddInput("bias_table", bias_table_dims, bias_table);
+ tester.AddInput("query_length", sequence_length_dims, sequence_length);
+ tester.AddInput("key_length", sequence_length_dims, sequence_length);
+ tester.AddOutput("output", output_dims, output_data);
+ }
+
+ if (enable_cuda) {
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+
+ if (enable_cpu) {
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+ }
+}
+
+TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP32) {
+ int max_distance = 128;
+ int num_buckets = 4;
+ int num_heads = 2;
+ int seq_len = 2;
+ int is_bidirectional = 1;
+
+ // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
+ // Save in col-major order in ORT
+ std::vector bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
+ std::vector sequence_length = {seq_len};
+
+ std::vector output_data = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};
+
+ RunRelativePositionBiasTest(bias_table,
+ sequence_length,
+ output_data,
+ max_distance,
+ num_buckets,
+ num_heads,
+ seq_len,
+ is_bidirectional);
+}
+
+TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16) {
+ int max_distance = 128;
+ int num_buckets = 4;
+ int num_heads = 2;
+ int seq_len = 2;
+ int is_bidirectional = 1;
+
+ // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
+ // Save in col-major order in ORT
+ std::vector bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
+ std::vector sequence_length = {seq_len};
+
+ std::vector output_data = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};
+
+ RunRelativePositionBiasTest(bias_table,
+ sequence_length,
+ output_data,
+ max_distance,
+ num_buckets,
+ num_heads,
+ seq_len,
+ is_bidirectional,
+ true);
+}
+
+TEST(RelativePositionBiasTest, RelativePositionBiasTest2_FP16) {
+ int max_distance = 128;
+ int num_buckets = 4;
+ int num_heads = 3;
+ int seq_len = 2;
+ int is_bidirectional = 1;
+
+ // Huggingface bias_table = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]].
+ // Save in col-major order in ORT
+ std::vector bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
+ std::vector sequence_length = {seq_len};
+
+ std::vector output_data = {1.f, 10.f, 4.f, 1.f, 2.f, 11.f, 5.f, 2.f, 3.f, 12.f, 6.f, 3.f};
+
+ RunRelativePositionBiasTest(bias_table,
+ sequence_length,
+ output_data,
+ max_distance,
+ num_buckets,
+ num_heads,
+ seq_len,
+ is_bidirectional,
+ true);
+}
+
+TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) {
+ int max_distance = 128;
+ int num_buckets = 4;
+ int num_heads = 3;
+ int seq_len = 2;
+ int is_bidirectional = 0;
+
+ std::vector bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
+ std::vector sequence_length = {seq_len};
+
+ std::vector output_data = {1.f, 1.f, 4.f, 1.f, 2.f, 2.f, 5.f, 2.f, 3.f, 3.f, 6.f, 3.f};
+
+ RunRelativePositionBiasTest(bias_table,
+ sequence_length,
+ output_data,
+ max_distance,
+ num_buckets,
+ num_heads,
+ seq_len,
+ is_bidirectional,
+ true);
+}
+
+} // namespace test
+} // namespace onnxruntime