mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
relational attention bias cuda op (#14149)
### Description This CUDA op implements the compute_bias() method of T5 attention, including the permutation. Notes: 1. bias_table needs to be saved in col-major order; be careful when implementing the fusion script. 2. The second input (sequence length) is placed on CPU (using a Shape node's output should be good). 3. The first dimension of the output is 1, so extra_add_qk in Attention should support broadcasting. 4. compute_bias() is only used in self-attention in T5. TODO: docs change will be applied later. ### Motivation and Context It is part of the process of optimizing T5 attention as well as T5-based generation models. Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
8e2163018d
commit
5eac2c1f41
10 changed files with 485 additions and 0 deletions
|
|
@ -71,6 +71,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.QuickGelu">com.microsoft.QuickGelu</a>
|
||||
* <a href="#com.microsoft.Range">com.microsoft.Range</a>
|
||||
* <a href="#com.microsoft.ReduceSumInteger">com.microsoft.ReduceSumInteger</a>
|
||||
* <a href="#com.microsoft.RelativePositionBias">com.microsoft.RelativePositionBias</a>
|
||||
* <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
|
||||
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
|
||||
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
|
||||
|
|
@ -3704,6 +3705,51 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.RelativePositionBias"></a><a name="com.microsoft.relativepositionbias">**com.microsoft.RelativePositionBias**</a>
|
||||
|
||||
Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>is_bidirectional</tt> : int</dt>
|
||||
<dd>Default value is 0.</dd>
|
||||
<dt><tt>max_distance</tt> : int (required)</dt>
|
||||
<dd>Max distance</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>bias_table</tt> : T</dt>
|
||||
<dd>2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)</dd>
|
||||
<dt><tt>query_length</tt> : U</dt>
|
||||
<dd>The length of query. Self Attention requires query_length = key_length</dd>
|
||||
<dt><tt>key_length</tt> : U</dt>
|
||||
<dd>The length of key.</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>output</tt> : T</dt>
|
||||
<dd>4D output tensor with shape (1, num_heads, sequence_length, sequence_length)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
|
||||
<dd>Constrain input and output types to float or half tensors.</dd>
|
||||
<dt><tt>U</tt> : tensor(int64)</dt>
|
||||
<dd>Constrain sequence_length to int tensors.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.RemovePadding"></a><a name="com.microsoft.removepadding">**com.microsoft.RemovePadding**</a>
|
||||
|
||||
Compress transformer input by removing paddings. It assumes padding is on the right side of sequence.
|
||||
|
|
|
|||
|
|
@ -796,6 +796,7 @@ Do not modify directly.*
|
|||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float16)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeWithOrder|*in* input:**F**<br> *in* scale_input:**S**<br> *out* output:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|
||||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|RelativePositionBias|*in* bias_table:**T**<br> *in* query_length:**U**<br> *in* key_length:**U**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
72
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
Normal file
72
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "relative_attn_bias.h"
|
||||
#include "relative_attn_bias_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
RelativePositionBias , \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
RelPosAttnBias<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
template <typename T>
|
||||
RelPosAttnBias<T>::RelPosAttnBias(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
is_bidirectional_ = info.GetAttrOrDefault<int64_t>("is_bidirectional", 0) == 1;
|
||||
|
||||
int64_t max_distance = 0;
|
||||
ORT_ENFORCE(info.GetAttr("max_distance", &max_distance).IsOK() && max_distance > 0);
|
||||
max_distance_ = static_cast<int>(max_distance);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status RelPosAttnBias<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* bias_table = context->Input<Tensor>(0);
|
||||
const Tensor* query_length = context->Input<Tensor>(1);
|
||||
const Tensor* key_length = context->Input<Tensor>(2);
|
||||
|
||||
const auto& bias_table_dims = bias_table->Shape().GetDims();
|
||||
const int64_t num_buckets = bias_table_dims[0];
|
||||
const int64_t num_heads = bias_table_dims[1];
|
||||
|
||||
const int64_t query_len = *query_length->Data<int64_t>();
|
||||
const int64_t key_len = *key_length->Data<int64_t>();
|
||||
|
||||
if (query_len != key_len) {
|
||||
ORT_THROW("Relatvie position bias currently only support query length equal to key length in Self Attention.");
|
||||
}
|
||||
|
||||
Tensor* output = context->Output(0, {1, num_heads, query_len, key_len});
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
return LaunchRelPosAttnBiasKernel<CudaT>(Stream(context),
|
||||
reinterpret_cast<CudaT*>(output->template MutableData<T>()),
|
||||
reinterpret_cast<const CudaT*>(bias_table->template Data<T>()),
|
||||
static_cast<int>(num_heads),
|
||||
static_cast<int>(query_len),
|
||||
static_cast<int>(num_buckets),
|
||||
max_distance_,
|
||||
is_bidirectional_);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
27
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
Normal file
27
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
template <typename T>
|
||||
class RelPosAttnBias final : public CudaKernel {
|
||||
public:
|
||||
RelPosAttnBias(const OpKernelInfo& op_kernel_info);
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
int max_distance_;
|
||||
bool is_bidirectional_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
125
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
Normal file
125
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "contrib_ops/cuda/bert/relative_attn_bias_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
// Fills relative_attention_bias with bucketed relative position biases.
//
// Launch layout: blockIdx.x selects the head. The seq_len * seq_len elements of
// that head are distributed over threadIdx.x and blockIdx.y (a launch with
// gridDim.y == 1 makes a single block walk the whole head).
//
// Memory layout:
//   relative_attention_bias       : element (head, row, col) at head * seq_len * seq_len + row * seq_len + col
//   relative_attention_bias_table : element (head, bucket) at head * num_bucket + bucket
//
// head_num is not read by the kernel; it is kept for interface compatibility.
template <typename T>
__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
                                           const T* relative_attention_bias_table,
                                           const int head_num,
                                           const int seq_len,
                                           const int num_bucket,
                                           const bool is_bidirectional,
                                           const int max_distance) {
  const int head_id = blockIdx.x;

  // 64-bit offsets: head_id * seq_len * seq_len does not fit in 32 bits for long sequences.
  const int64_t elements_per_head = static_cast<int64_t>(seq_len) * seq_len;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.y;
  T* head_output = relative_attention_bias + head_id * elements_per_head;
  const T* head_table = relative_attention_bias_table + static_cast<int64_t>(head_id) * num_bucket;

  for (int64_t seq_id = static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
       seq_id < elements_per_head;
       seq_id += stride) {
    const int row_id = static_cast<int>(seq_id / seq_len);
    const int col_id = static_cast<int>(seq_id % seq_len);

    int relative_position = col_id - row_id;

    int relative_buckets = 0;
    int tmp_num_bucket = num_bucket;

    if (is_bidirectional) {
      // Half of the buckets is used for positive offsets, the other half for the rest.
      tmp_num_bucket /= 2;
      if (relative_position > 0) {
        relative_buckets += tmp_num_bucket;
      } else {
        relative_position *= -1;
      }
    } else {
      // Unidirectional: positive offsets all collapse to distance 0.
      if (relative_position > 0) {
        relative_position = 0;
      } else {
        relative_position *= -1;
      }
    }

    // Distances below max_exact map to their own bucket.
    const int max_exact = tmp_num_bucket / 2;
    const bool is_small = relative_position < max_exact;

    if (is_small) {
      relative_buckets += relative_position;
    } else {
      // Larger distances are placed logarithmically up to max_distance and
      // clamped to the last bucket. Only evaluated when needed, so logf(0)
      // is never computed for the exact-bucket case.
      int relative_position_if_large =
          max_exact + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact) *
                            (tmp_num_bucket - max_exact));
      relative_position_if_large = min(relative_position_if_large, tmp_num_bucket - 1);
      relative_buckets += relative_position_if_large;
    }

    head_output[seq_id] = head_table[relative_buckets];
  }
}
|
||||
|
||||
// Launches buildRelativeAttentionBias on the given stream.
//
//   output     : device buffer receiving num_heads * seq_len * seq_len elements.
//   bias_table : device buffer read as num_heads rows of num_bucket elements.
//
// The launch uses one block of 256 threads per head. Returns the status of the
// kernel launch (launch errors only; the kernel runs asynchronously on stream).
template <typename T>
Status LaunchRelPosAttnBiasKernel(
    cudaStream_t stream,
    T* output,
    const T* bias_table,
    const int num_heads,
    const int seq_len,
    const int num_bucket,
    const int max_distance,
    const bool is_bidirectional) {
  // An empty output needs no work; a zero-sized grid would otherwise be
  // reported as an invalid launch configuration.
  if (num_heads == 0 || seq_len == 0) {
    return Status::OK();
  }

  const dim3 grid(num_heads);
  const dim3 block(256);

  buildRelativeAttentionBias<<<grid, block, 0, stream>>>(output,
                                                         bias_table,
                                                         num_heads,
                                                         seq_len,
                                                         num_bucket,
                                                         is_bidirectional,
                                                         max_distance);

  return CUDA_CALL(cudaGetLastError());
}
|
||||
|
||||
// Explicit instantiations of the launcher for float and half element types.
template Status LaunchRelPosAttnBiasKernel<float>(
    cudaStream_t stream, float* output, const float* bias_table,
    const int num_heads, const int seq_len, const int num_bucket,
    const int max_distance, const bool is_bidirectional);

template Status LaunchRelPosAttnBiasKernel<half>(
    cudaStream_t stream, half* output, const half* bias_table,
    const int num_heads, const int seq_len, const int num_bucket,
    const int max_distance, const bool is_bidirectional);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
26
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
Normal file
26
onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
Status LaunchRelPosAttnBiasKernel(
|
||||
cudaStream_t stream,
|
||||
T* output,
|
||||
const T* bias_table,
|
||||
const int num_heads,
|
||||
const int seq_len,
|
||||
const int num_bucket,
|
||||
const int max_distance,
|
||||
const bool is_bidirectional
|
||||
);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -28,6 +28,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
|
||||
|
|
@ -151,6 +153,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding)>,
|
||||
|
|
|
|||
|
|
@ -436,6 +436,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
}));
|
||||
|
||||
// Schema of the RelativePositionBias contrib operator (version 1).
ONNX_MS_OPERATOR_SET_SCHEMA(
    RelativePositionBias, 1,
    OpSchema()
        .SetDoc("Compute binned relative position bias for T5 model. ref: https://arxiv.org/abs/1803.02155v2")
        .Attr("max_distance", "Max distance", AttributeProto::INT)
        .Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast<int64_t>(0))
        .Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T")
        .Input(1, "query_length", "The length of query. Self Attention requires query_length = key_length", "U")
        .Input(2, "key_length", "The length of key.", "U")
        .Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T")
        .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
        .TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.")
        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
          propagateElemTypeFromInputToOutput(ctx, 0, 0);

          // Without a known bias_table shape only the element type can be inferred.
          if (!hasInputShape(ctx, 0)) {
            return;
          }

          auto& bias_table_shape = getInputShape(ctx, 0);
          // dim(1) is read below, so the rank must be checked first.
          if (bias_table_shape.dim_size() != 2) {
            fail_shape_inference("Input 'bias_table' is expected to have 2 dimensions");
          }

          // Output is (1, num_heads, ?, ?): num_heads is copied from bias_table dim 1,
          // and the two trailing dimensions are left unset.
          TensorShapeProto output_shape;
          output_shape.add_dim()->set_dim_value(1);
          *output_shape.add_dim() = bias_table_shape.dim(1);
          output_shape.add_dim();
          output_shape.add_dim();
          updateOutputShape(ctx, 0, output_shape);
        }));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
SkipLayerNormalization, 1,
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
|
||||
|
|
@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
|
||||
|
|
|
|||
159
onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
Normal file
159
onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
static void RunRelativePositionBiasTest(
|
||||
const std::vector<float>& bias_table, // Shape = [num_buckets, num_heads]
|
||||
const std::vector<int64_t>& sequence_length, // Shape = [1]
|
||||
const std::vector<float>& output_data, // Shape = [1, num_heads, sequence_length, sequence_length]
|
||||
int max_distance,
|
||||
int num_buckets,
|
||||
int num_heads,
|
||||
int seq_len,
|
||||
int is_bidirectional,
|
||||
bool use_float16 = false) {
|
||||
int min_cuda_architecture = use_float16 ? 530 : 0;
|
||||
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
|
||||
bool enable_cpu = false;
|
||||
if (enable_cpu || enable_cuda) {
|
||||
OpTester tester("RelativePositionBias", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("max_distance", static_cast<int64_t>(max_distance));
|
||||
tester.AddAttribute<int64_t>("is_bidirectional", static_cast<int64_t>(is_bidirectional));
|
||||
|
||||
std::vector<int64_t> bias_table_dims = {num_buckets, num_heads};
|
||||
std::vector<int64_t> sequence_length_dims = {1};
|
||||
std::vector<int64_t> output_dims = {1, num_heads, seq_len, seq_len};
|
||||
|
||||
if (use_float16) {
|
||||
tester.AddInput<MLFloat16>("bias_table", bias_table_dims, ToFloat16(bias_table));
|
||||
tester.AddInput<int64_t>("query_length", sequence_length_dims, sequence_length);
|
||||
tester.AddInput<int64_t>("key_length", sequence_length_dims, sequence_length);
|
||||
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
|
||||
} else {
|
||||
tester.AddInput<float>("bias_table", bias_table_dims, bias_table);
|
||||
tester.AddInput<int64_t>("query_length", sequence_length_dims, sequence_length);
|
||||
tester.AddInput<int64_t>("key_length", sequence_length_dims, sequence_length);
|
||||
tester.AddOutput<float>("output", output_dims, output_data);
|
||||
}
|
||||
|
||||
if (enable_cuda) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
if (enable_cpu) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// float32, bidirectional: 4 buckets, 2 heads, sequence length 2.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP32) {
  const int max_distance = 128;
  const int num_buckets = 4;
  const int num_heads = 2;
  const int seq_len = 2;
  const int is_bidirectional = 1;

  // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
  // Stored in col-major order for ORT.
  const std::vector<float> table{1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
  const std::vector<int64_t> lengths{seq_len};
  const std::vector<float> expected{1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};

  RunRelativePositionBiasTest(table, lengths, expected,
                              max_distance, num_buckets, num_heads, seq_len, is_bidirectional);
}
|
||||
|
||||
// float16, bidirectional: 4 buckets, 2 heads, sequence length 2.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16) {
  const int max_distance = 128;
  const int num_buckets = 4;
  const int num_heads = 2;
  const int seq_len = 2;
  const int is_bidirectional = 1;

  // Huggingface bias_table = [[1, 2], [3, 4], [5, 6], [7, 8]].
  // Stored in col-major order for ORT.
  const std::vector<float> table{1.f, 3.f, 5.f, 7.f, 2.f, 4.f, 6.f, 8.f};
  const std::vector<int64_t> lengths{seq_len};
  const std::vector<float> expected{1.f, 7.f, 3.f, 1.f, 2.f, 8.f, 4.f, 2.f};

  RunRelativePositionBiasTest(table, lengths, expected,
                              max_distance, num_buckets, num_heads, seq_len, is_bidirectional,
                              /*use_float16*/ true);
}
|
||||
|
||||
// float16, bidirectional: 4 buckets, 3 heads, sequence length 2.
TEST(RelativePositionBiasTest, RelativePositionBiasTest2_FP16) {
  const int max_distance = 128;
  const int num_buckets = 4;
  const int num_heads = 3;
  const int seq_len = 2;
  const int is_bidirectional = 1;

  // Huggingface bias_table = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]].
  // Stored in col-major order for ORT.
  const std::vector<float> table{1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
  const std::vector<int64_t> lengths{seq_len};
  const std::vector<float> expected{1.f, 10.f, 4.f, 1.f, 2.f, 11.f, 5.f, 2.f, 3.f, 12.f, 6.f, 3.f};

  RunRelativePositionBiasTest(table, lengths, expected,
                              max_distance, num_buckets, num_heads, seq_len, is_bidirectional,
                              /*use_float16*/ true);
}
|
||||
|
||||
// float16, unidirectional (is_bidirectional = 0): 4 buckets, 3 heads, sequence length 2.
TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) {
  const int max_distance = 128;
  const int num_buckets = 4;
  const int num_heads = 3;
  const int seq_len = 2;
  const int is_bidirectional = 0;

  const std::vector<float> table{1.f, 4.f, 7.f, 10.f, 2.f, 5.f, 8.f, 11.f, 3.f, 6.f, 9.f, 12.f};
  const std::vector<int64_t> lengths{seq_len};
  const std::vector<float> expected{1.f, 1.f, 4.f, 1.f, 2.f, 2.f, 5.f, 2.f, 3.f, 3.f, 6.f, 3.f};

  RunRelativePositionBiasTest(table, lengths, expected,
                              max_distance, num_buckets, num_heads, seq_len, is_bidirectional,
                              /*use_float16*/ true);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue