relational attention bias cuda op (#14149)

### Description

This cuda op implements the compute_bias() method in T5 Attention
including the permutation.

Notes:
1. bias_table must be stored in column-major order; keep this in mind when
implementing the fusion script.
2. The sequence-length inputs (query_length and key_length) are placed on the CPU
(feeding them from a Shape node's output should work).
3. The first dimension of the output is 1, so extra_add_qk in Attention
must support broadcasting.
4. compute_bias() is only used in self-attention in T5.

TODO: docs change will be applied later

### Motivation and Context
It's part of the process of optimizing t5 attention as well as t5 based
generation model

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-01-06 17:32:58 -08:00 committed by GitHub
parent 8e2163018d
commit 5eac2c1f41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 485 additions and 0 deletions

View file

@ -71,6 +71,7 @@ Do not modify directly.*
* <a href="#com.microsoft.QuickGelu">com.microsoft.QuickGelu</a>
* <a href="#com.microsoft.Range">com.microsoft.Range</a>
* <a href="#com.microsoft.ReduceSumInteger">com.microsoft.ReduceSumInteger</a>
* <a href="#com.microsoft.RelativePositionBias">com.microsoft.RelativePositionBias</a>
* <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
@ -3704,6 +3705,51 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.RelativePositionBias"></a><a name="com.microsoft.relativepositionbias">**com.microsoft.RelativePositionBias**</a>
Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>is_bidirectional</tt> : int</dt>
<dd>Default value is 0.</dd>
<dt><tt>max_distance</tt> : int (required)</dt>
<dd>Max distance</dd>
</dl>
#### Inputs
<dl>
<dt><tt>bias_table</tt> : T</dt>
<dd>2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)</dd>
<dt><tt>query_length</tt> : U</dt>
<dd>The length of query. Self Attention requires query_length = key_length</dd>
<dt><tt>key_length</tt> : U</dt>
<dd>The length of key.</dd>
</dl>
#### Outputs
<dl>
<dt><tt>output</tt> : T</dt>
<dd>4D output tensor with shape (1, num_heads, sequence_length, sequence_length)</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float or half tensors.</dd>
<dt><tt>U</tt> : tensor(int64)</dt>
<dd>Constrain sequence_length to int tensors.</dd>
</dl>
### <a name="com.microsoft.RemovePadding"></a><a name="com.microsoft.removepadding">**com.microsoft.RemovePadding**</a>
Compress transformer input by removing paddings. It assumes padding is on the right side of sequence.

View file

@ -796,6 +796,7 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float16)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuantizeWithOrder|*in* input:**F**<br> *in* scale_input:**S**<br> *out* output:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RelativePositionBias|*in* bias_table:**T**<br> *in* query_length:**U**<br> *in* key_length:**U**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|

View file

@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "relative_attn_bias.h"
#include "relative_attn_bias_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Registers RelPosAttnBias<T> as the CUDA execution provider kernel for the
// RelativePositionBias op (kMSDomain, version 1), constrained to element type T.
// Inputs 1 and 2 (query_length and key_length) are tagged OrtMemTypeCPUInput
// because ComputeInternal dereferences their data pointers on the host.
// NOTE: keep comments out of the macro body; the line continuations would swallow them.
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RelativePositionBias , \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
RelPosAttnBias<T>);
// One registration per supported element type: float and MLFloat16.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
// Reads the operator attributes once, when the kernel is constructed.
//   max_distance:     required; construction fails unless it is present and > 0.
//   is_bidirectional: optional (defaults to 0); treated as true only when it equals 1.
template <typename T>
RelPosAttnBias<T>::RelPosAttnBias(const OpKernelInfo& info) : CudaKernel(info) {
  int64_t max_distance = 0;
  ORT_ENFORCE(info.GetAttr("max_distance", &max_distance).IsOK() && max_distance > 0);
  max_distance_ = static_cast<int>(max_distance);

  const int64_t bidirectional_attr = info.GetAttrOrDefault<int64_t>("is_bidirectional", 0);
  is_bidirectional_ = (bidirectional_attr == 1);
}
// Computes the binned relative position bias for one self-attention layer.
//
// Inputs:
//   0: bias_table   - 2D tensor; dim 0 is read as num_buckets, dim 1 as num_heads.
//   1: query_length - int64 tensor in CPU memory; only its first element is read.
//   2: key_length   - int64 tensor in CPU memory; only its first element is read.
// Output:
//   0: tensor of shape (1, num_heads, query_length, key_length).
//
// Throws when bias_table is not 2D, when a length tensor is empty, or when
// query_length != key_length. Otherwise returns the status of the kernel launch.
template <typename T>
Status RelPosAttnBias<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* bias_table = context->Input<Tensor>(0);
  const Tensor* query_length = context->Input<Tensor>(1);
  const Tensor* key_length = context->Input<Tensor>(2);

  // Validate the rank before indexing the dims below; a 0D/1D table would
  // otherwise be read out of bounds.
  const auto bias_table_rank = bias_table->Shape().NumDimensions();
  ORT_ENFORCE(bias_table_rank == 2, "bias_table must be a 2D tensor, got rank ", bias_table_rank);

  const auto& bias_table_dims = bias_table->Shape().GetDims();
  const int64_t num_buckets = bias_table_dims[0];
  const int64_t num_heads = bias_table_dims[1];

  // Both length tensors are dereferenced on the host, so each must hold at least one value.
  ORT_ENFORCE(query_length->Shape().Size() >= 1 && key_length->Shape().Size() >= 1,
              "query_length and key_length must each contain one element");

  const int64_t query_len = *query_length->Data<int64_t>();
  const int64_t key_len = *key_length->Data<int64_t>();

  if (query_len != key_len) {
    ORT_THROW("Relative position bias currently only support query length equal to key length in Self Attention.");
  }

  Tensor* output = context->Output(0, {1, num_heads, query_len, key_len});

  using CudaT = typename ToCudaType<T>::MappedType;

  // key_len equals query_len here, so a single sequence length is passed to the launcher.
  return LaunchRelPosAttnBiasKernel<CudaT>(Stream(context),
                                           reinterpret_cast<CudaT*>(output->template MutableData<T>()),
                                           reinterpret_cast<const CudaT*>(bias_table->template Data<T>()),
                                           static_cast<int>(num_heads),
                                           static_cast<int>(query_len),
                                           static_cast<int>(num_buckets),
                                           max_distance_,
                                           is_bidirectional_);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
// CUDA kernel class for the RelativePositionBias op.
// The constructor and ComputeInternal are defined out of line; the two
// private fields are the per-kernel configuration they use.
template <typename T>
class RelPosAttnBias final : public CudaKernel {
 public:
  // explicit: an OpKernelInfo should never convert implicitly into a kernel.
  explicit RelPosAttnBias(const OpKernelInfo& op_kernel_info);

  Status ComputeInternal(OpKernelContext* ctx) const override;

 private:
  // Default-initialized so the fields never hold indeterminate values.
  int max_distance_{0};
  bool is_bidirectional_{false};
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,125 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/relative_attn_bias_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
// Fills the relative attention bias for one head per thread block.
//
// blockIdx.x selects the head; the threads of the block step through that
// head's seq_len * seq_len (row, col) positions with a stride of blockDim.x.
//
// Parameters:
//   relative_attention_bias       - output, written at [head][row][col].
//   relative_attention_bias_table - input, read at [head][bucket] with
//                                   num_bucket entries per head.
//   head_num                      - number of heads (not read in the body).
//   seq_len                       - number of rows and of columns per head.
//   num_bucket                    - total number of buckets per head.
//   is_bidirectional              - when true, half of the buckets serve
//                                   positive offsets (col > row) and half the rest;
//                                   when false, positive offsets are clamped to 0.
//   max_distance                  - distance at which the logarithmic bucketing saturates.
template <typename T>
__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
                                           const T* relative_attention_bias_table,
                                           const int head_num,
                                           const int seq_len,
                                           const int num_bucket,
                                           const bool is_bidirectional,
                                           const int max_distance) {
  const int head_id = blockIdx.x;

  // Use 64-bit sizes and offsets: seq_len * seq_len and
  // head_id * seq_len * seq_len exceed the range of a 32-bit int for long
  // sequences or many heads.
  const int64_t positions_per_head = static_cast<int64_t>(seq_len) * seq_len;
  T* head_output = relative_attention_bias + head_id * positions_per_head;
  const T* head_table = relative_attention_bias_table + static_cast<int64_t>(head_id) * num_bucket;

  // Loop-invariant bucket layout.
  const int buckets_per_direction = is_bidirectional ? num_bucket / 2 : num_bucket;
  const int max_exact = buckets_per_direction / 2;  // offsets below this get their own bucket

  for (int64_t pos = threadIdx.x; pos < positions_per_head; pos += blockDim.x) {
    const int row_id = static_cast<int>(pos / seq_len);
    const int col_id = static_cast<int>(pos % seq_len);

    int relative_position = col_id - row_id;
    int relative_bucket = 0;

    if (is_bidirectional) {
      if (relative_position > 0) {
        // Positive offsets use the upper half of the buckets.
        relative_bucket += buckets_per_direction;
      } else {
        relative_position = -relative_position;
      }
    } else {
      // Unidirectional: positive offsets collapse to distance 0.
      relative_position = relative_position > 0 ? 0 : -relative_position;
    }

    if (relative_position < max_exact) {
      relative_bucket += relative_position;
    } else {
      // Logarithmic bucketing for larger distances. Evaluated only on this
      // branch so logf(0) is never computed and converted to int.
      const int log_bucket =
          max_exact +
          static_cast<int>(logf(relative_position * 1.0f / max_exact) /
                           logf(static_cast<float>(max_distance) / max_exact) *
                           (buckets_per_direction - max_exact));
      relative_bucket += min(log_bucket, buckets_per_direction - 1);
    }

    head_output[pos] = head_table[relative_bucket];
  }
}
// Launches buildRelativeAttentionBias on `stream` with one block per head and
// 256 threads per block.
//
// Returns Status::OK() without launching when there is nothing to compute
// (num_heads <= 0 or seq_len <= 0); otherwise returns the status derived from
// cudaGetLastError() after the launch.
template <typename T>
Status LaunchRelPosAttnBiasKernel(
    cudaStream_t stream,
    T* output,
    const T* bias_table,
    const int num_heads,
    const int seq_len,
    const int num_bucket,
    const int max_distance,
    const bool is_bidirectional) {
  // A grid with zero blocks is an invalid launch configuration, and an empty
  // output needs no work at all.
  if (num_heads <= 0 || seq_len <= 0) {
    return Status::OK();
  }

  constexpr int kThreadsPerBlock = 256;
  const dim3 grid(num_heads);
  const dim3 block(kThreadsPerBlock);

  buildRelativeAttentionBias<<<grid, block, 0, stream>>>(output,
                                                         bias_table,
                                                         num_heads,
                                                         seq_len,
                                                         num_bucket,
                                                         is_bidirectional,
                                                         max_distance);

  return CUDA_CALL(cudaGetLastError());
}
// Explicit instantiations of the launcher for float and half.
// Parameter order: stream, output, bias_table, num_heads, seq_len, num_bucket,
// max_distance, is_bidirectional.
template Status LaunchRelPosAttnBiasKernel<float>(
    cudaStream_t, float*, const float*, const int, const int, const int, const int, const bool);

template Status LaunchRelPosAttnBiasKernel<half>(
    cudaStream_t, half*, const half*, const int, const int, const int, const int, const bool);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
/// Launches the CUDA kernel that builds the relative attention bias.
///
/// @param stream            CUDA stream the kernel is enqueued on.
/// @param output            Destination buffer for the computed bias.
/// @param bias_table        Source bias table.
/// @param num_heads         Number of attention heads.
/// @param seq_len           Sequence length.
/// @param num_bucket        Number of buckets in the bias table.
/// @param max_distance      Value of the op's max_distance attribute.
/// @param is_bidirectional  Value of the op's is_bidirectional attribute.
/// @return Status of the kernel launch.
template <typename T>
Status LaunchRelPosAttnBiasKernel(cudaStream_t stream,
                                  T* output,
                                  const T* bias_table,
                                  int num_heads,
                                  int seq_len,
                                  int num_bucket,
                                  int max_distance,
                                  bool is_bidirectional);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -28,6 +28,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
@ -151,6 +153,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding)>,

View file

@ -436,6 +436,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
schema.BuildFunction(functionProto);
return true;
}));
// Schema for com.microsoft.RelativePositionBias, version 1.
// Shape inference produces a rank-4 output (1, num_heads, ?, ?): num_heads is
// copied from dim 1 of bias_table and the two sequence dimensions are left
// unknown because they come from the runtime values of the length inputs.
ONNX_MS_OPERATOR_SET_SCHEMA(
    RelativePositionBias, 1,
    OpSchema()
        .SetDoc("Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2")
        .Attr("max_distance", "Max distance", AttributeProto::INT)
        .Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast<int64_t>(0))
        .Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T")
        .Input(1, "query_length", "The length of query. Self Attention requires query_length = key_length", "U")
        .Input(2, "key_length", "The length of key.", "U")
        .Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T")
        .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
        .TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.")
        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);

          // Without a known bias_table shape there is nothing to infer; reading
          // dim(1) of an absent shape would be an out-of-range access.
          if (!hasInputShape(ctx, 0)) {
            return;
          }

          const auto& bias_table_shape = getInputShape(ctx, 0);
          if (bias_table_shape.dim_size() != 2) {
            fail_shape_inference("bias_table is expected to have 2 dimensions, got ",
                                 bias_table_shape.dim_size());
          }

          TensorShapeProto output_shape;
          output_shape.add_dim()->set_dim_value(1);
          *output_shape.add_dim() = bias_table_shape.dim(1);  // num_heads
          output_shape.add_dim();                             // query length, unknown here
          output_shape.add_dim();                             // key length, unknown here
          updateOutputShape(ctx, 0, output_shape);
        }));
ONNX_MS_OPERATOR_SET_SCHEMA(
SkipLayerNormalization, 1,

View file

@ -78,6 +78,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());

View file

@ -0,0 +1,159 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
static void RunRelativePositionBiasTest(
const std::vector<float>& bias_table, // Shape = [num_buckets, num_heads]
const std::vector<int64_t>& sequence_length, // Shape = [1]
const std::vector<float>& output_data, // Shape = [1, num_heads, sequence_length, sequence_length]
int max_distance,
int num_buckets,
int num_heads,
int seq_len,
int is_bidirectional,
bool use_float16 = false) {
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_cpu = false;
if (enable_cpu || enable_cuda) {
OpTester tester("RelativePositionBias", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("max_distance", static_cast<int64_t>(max_distance));
tester.AddAttribute<int64_t>("is_bidirectional", static_cast<int64_t>(is_bidirectional));
std::vector<int64_t> bias_table_dims = {num_buckets, num_heads};
std::vector<int64_t> sequence_length_dims = {1};
std::vector<int64_t> output_dims = {1, num_heads, seq_len, seq_len};
if (use_float16) {
tester.AddInput<MLFloat16>("bias_table", bias_table_dims, ToFloat16(bias_table));
tester.AddInput<int64_t>("query_length", sequence_length_dims, sequence_length);
tester.AddInput<int64_t>("key_length", sequence_length_dims, sequence_length);
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("bias_table", bias_table_dims, bias_table);
tester.AddInput<int64_t>("query_length", sequence_length_dims, sequence_length);
tester.AddInput<int64_t>("key_length", sequence_length_dims, sequence_length);
tester.AddOutput<float>("output", output_dims, output_data);
}
if (enable_cuda) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (enable_cpu) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
}
// Bidirectional bias in fp32: 4 buckets, 2 heads, sequence length 2.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP32) {
  constexpr int kMaxDistance = 128;
  constexpr int kNumBuckets = 4;
  constexpr int kNumHeads = 2;
  constexpr int kSeqLen = 2;
  constexpr int kIsBidirectional = 1;

  // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
  // Save in col-major order in ORT
  const std::vector<float> bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
  const std::vector<int64_t> sequence_length = {kSeqLen};

  // Per head, the 2x2 positions map to buckets 0, 3, 1, 0.
  const std::vector<float> expected_output = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};

  RunRelativePositionBiasTest(bias_table, sequence_length, expected_output,
                              kMaxDistance, kNumBuckets, kNumHeads, kSeqLen, kIsBidirectional);
}
// Same data as the fp32 bidirectional case, run in half precision.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16) {
  constexpr int kMaxDistance = 128;
  constexpr int kNumBuckets = 4;
  constexpr int kNumHeads = 2;
  constexpr int kSeqLen = 2;
  constexpr int kIsBidirectional = 1;
  constexpr bool kUseFloat16 = true;

  // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
  // Save in col-major order in ORT
  const std::vector<float> bias_table = {1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
  const std::vector<int64_t> sequence_length = {kSeqLen};

  // Per head, the 2x2 positions map to buckets 0, 3, 1, 0.
  const std::vector<float> expected_output = {1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};

  RunRelativePositionBiasTest(bias_table, sequence_length, expected_output,
                              kMaxDistance, kNumBuckets, kNumHeads, kSeqLen, kIsBidirectional,
                              kUseFloat16);
}
// Bidirectional bias in half precision with 3 heads.
TEST(RelativePositionBiasTest, RelativePositionBiasTest2_FP16) {
  constexpr int kMaxDistance = 128;
  constexpr int kNumBuckets = 4;
  constexpr int kNumHeads = 3;
  constexpr int kSeqLen = 2;
  constexpr int kIsBidirectional = 1;
  constexpr bool kUseFloat16 = true;

  // Huggingface bias_table = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]].
  // Save in col-major order in ORT
  const std::vector<float> bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
  const std::vector<int64_t> sequence_length = {kSeqLen};

  // Per head, the 2x2 positions map to buckets 0, 3, 1, 0.
  const std::vector<float> expected_output = {1.f, 10.f, 4.f, 1.f, 2.f, 11.f, 5.f, 2.f, 3.f, 12.f, 6.f, 3.f};

  RunRelativePositionBiasTest(bias_table, sequence_length, expected_output,
                              kMaxDistance, kNumBuckets, kNumHeads, kSeqLen, kIsBidirectional,
                              kUseFloat16);
}
// Unidirectional bias in half precision with 3 heads: positions where the
// column is ahead of the row are clamped to distance 0.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) {
  constexpr int kMaxDistance = 128;
  constexpr int kNumBuckets = 4;
  constexpr int kNumHeads = 3;
  constexpr int kSeqLen = 2;
  constexpr int kIsBidirectional = 0;
  constexpr bool kUseFloat16 = true;

  // Each run of 4 values is one head's bucket row.
  const std::vector<float> bias_table = {1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
  const std::vector<int64_t> sequence_length = {kSeqLen};

  // Per head, the 2x2 positions map to buckets 0, 0, 1, 0.
  const std::vector<float> expected_output = {1.f, 1.f, 4.f, 1.f, 2.f, 2.f, 5.f, 2.f, 3.f, 3.f, 6.f, 3.f};

  RunRelativePositionBiasTest(bias_table, sequence_length, expected_output,
                              kMaxDistance, kNumBuckets, kNumHeads, kSeqLen, kIsBidirectional,
                              kUseFloat16);
}
} // namespace test
} // namespace onnxruntime