[CUDA] RelativePositionBias supports input with padding removed (#16923)

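Update the CUDA GatedRelativePositionBias operator to support input with padding removed: query_layer may be packed as (token_count, num_heads x head_size) when the optional token_offset input is provided.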
- [x] add bias transpose kernel
- [x] add test
- [x] update operator document
This commit is contained in:
Tianlei Wu 2023-08-01 16:39:09 -07:00 committed by GitHub
parent afac67bcc3
commit 50bf310dea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 254 additions and 37 deletions

View file

@ -1886,11 +1886,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
</dl>
#### Inputs
#### Inputs (6 - 7)
<dl>
<dt><tt>query_layer</tt> : T</dt>
<dd>tensor with shape (batch_size, seq_len, num_heads x head_size)</dd>
<dd>tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)</dd>
<dt><tt>query_bias</tt> : T</dt>
<dd>1-d tensor with shape (num_heads x head_size)</dd>
<dt><tt>rel_pos</tt> : T</dt>
@ -1901,6 +1901,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>bias for the gated_ur_linear, shape (D)</dd>
<dt><tt>eco_a</tt> : T</dt>
<dd>tensor of shape (1, num_heads, 1, 1)</dd>
<dt><tt>token_offset</tt> (optional) : M</dt>
<dd>offset of each token with shape (batch_size, seq_len)</dd>
</dl>
#### Outputs
@ -1915,6 +1917,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain token_offset to integer types</dd>
</dl>

View file

@ -831,7 +831,7 @@ Do not modify directly.*
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|

View file

@ -100,17 +100,33 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
const Tensor& weight_tensor = *context->Input<Tensor>(3);
const Tensor& bias_tensor = *context->Input<Tensor>(4);
const Tensor& eco_a_tensor = *context->Input<Tensor>(5);
const Tensor* token_offset_tensor = context->Input<Tensor>(6);
const auto& query_dims = query_tensor.Shape().GetDims();
ORT_ENFORCE(query_dims.size() == 3);
ORT_ENFORCE(query_dims[2] > 0);
ORT_ENFORCE(query_dims[2] % num_heads_ == 0);
const auto batch_size = SafeInt<int>(query_dims[0]);
const auto seq_len = SafeInt<int>(query_dims[1]);
const auto head_size = SafeInt<int>(query_dims[2] / num_heads_);
bool is_padding_removed = (token_offset_tensor != nullptr);
ORT_ENFORCE(query_dims.size() == (is_padding_removed ? 2 : 3));
int hidden_size = static_cast<int>(query_dims[query_dims.size() - 1]);
ORT_ENFORCE(hidden_size > 0);
ORT_ENFORCE(hidden_size % num_heads_ == 0);
const auto head_size = SafeInt<int>(hidden_size / num_heads_);
int batch_size;
int seq_len;
int token_count = 0;
if (is_padding_removed) {
const auto& token_offset_dims = token_offset_tensor->Shape().GetDims();
batch_size = SafeInt<int>(token_offset_dims[0]);
seq_len = SafeInt<int>(token_offset_dims[1]);
token_count = SafeInt<int>(query_dims[0]);
} else {
batch_size = SafeInt<int>(query_dims[0]);
seq_len = SafeInt<int>(query_dims[1]);
}
ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1);
ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]);
ORT_ENFORCE(query_bias_tensor.Shape()[0] == hidden_size);
const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims();
ORT_ENFORCE(rel_pos_dims.size() == 4);
@ -149,21 +165,31 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
constexpr int format = 1;
constexpr int total_maxtrix = 1;
constexpr int num_matrix_to_transpose = 1;
LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
batch_size, seq_len, num_heads_, head_size,
reinterpret_cast<const CudaT*>(query_tensor.template Data<T>()),
reinterpret_cast<const CudaT*>(query_bias_tensor.template Data<T>()),
reinterpret_cast<CudaT*>(workspace.get()),
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
cudaStream_t stream = Stream(context);
if (!is_padding_removed) {
// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
constexpr int format = 1;
constexpr int total_maxtrix = 1;
constexpr int num_matrix_to_transpose = 1;
LaunchAddBiasTranspose(stream, num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
batch_size, seq_len, num_heads_, head_size,
reinterpret_cast<const CudaT*>(query_tensor.Data<T>()),
reinterpret_cast<const CudaT*>(query_bias_tensor.Data<T>()),
reinterpret_cast<CudaT*>(workspace.get()),
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);
} else {
RestorePaddingAddBiasTranspose(reinterpret_cast<const CudaT*>(query_tensor.Data<T>()),
reinterpret_cast<const CudaT*>(query_bias_tensor.Data<T>()),
reinterpret_cast<CudaT*>(workspace.get()),
batch_size, seq_len, num_heads_, head_size,
token_offset_tensor->Data<int>(),
token_count, stream);
}
// reuse output if possible
CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
: (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
int ld_gemm_output = reuse_output ? seq_len : D;
int ld_gemm_output = reuse_output ? seq_len : static_cast<int>(D);
const CudaT one = ToCudaType<T>::FromFloat(1.0f);
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
@ -177,7 +203,7 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
&zero, gemm_output, ld_gemm_output, device_prop));
auto status = LaunchGatedRelativePositionBiasKernel<CudaT>(
device_prop, Stream(context),
device_prop, stream,
reinterpret_cast<CudaT*>(output->template MutableData<T>()),
reinterpret_cast<const CudaT*>(rel_pos_tensor.template Data<T>()),
reinterpret_cast<const CudaT*>(gemm_output),

View file

@ -29,6 +29,146 @@ namespace cuda {
using namespace onnxruntime::cuda;
static constexpr int32_t kMAX_THREADS_PER_BLOCK = 256;
// Restores padding for a packed query while adding bias and transposing to BNSH layout.
//
// Launch configuration:
//   Grid:  (S, B) - one block per slot; gridDim.x is used as the sequence length S.
//   Block: 1-D (256 threads in the launcher); threads stride over the N * H elements of one token.
// For packed input
//   query:  TxNxH (only rows with index < token_count are read)
//   Output: BxNxSxH
// Where:
//   T is token_count
//   B is batch_size
//   S is sequence_length
//   N is num_heads
//   H is head_size
//
// token_offset is read at index (blockIdx.y * S + blockIdx.x), i.e. it must hold B * S entries.
// Entry i is the flattened (b * S + s) position in the padded layout that block i writes.
// Blocks with i < token_count write query row i (plus bias when biases is not null);
// the remaining blocks write bias only (or a value-initialized T when biases is null).
//
// All element offsets are computed in 64 bits so that tensors holding more than 2^31
// elements are addressed correctly.
template <typename T>
__global__ void TransposeQKV_TNH_3BNSH(
    const T* query,
    const T* biases,
    int32_t N,
    int32_t H_QK,
    T* q,
    const int32_t* token_offset,
    int32_t token_count) {
  const int S = gridDim.x;
  const int slot_s = blockIdx.x;
  const int slot_b = blockIdx.y;

  // Index of this block's token in the packed layout, and where it lands in the padded one.
  const int packing_token_idx = slot_b * S + slot_s;
  const int padding_token_idx = token_offset[packing_token_idx];
  const int b = padding_token_idx / S;
  const int s = padding_token_idx % S;

  const int D_QK = N * H_QK;

  // Distance (in elements of T) between two consecutive heads of the same token in BNSH.
  const int64_t head_stride = static_cast<int64_t>(S) * H_QK;
  q += (static_cast<int64_t>(b) * N * S + s) * H_QK;

  if (packing_token_idx < token_count) {
    // Only real tokens have a row in the packed query; advance the pointer here so it is
    // never moved past the end of the packed buffer for padding slots.
    query += static_cast<int64_t>(packing_token_idx) * D_QK;
    for (int i = threadIdx.x; i < D_QK; i += blockDim.x) {
      const int h = i % H_QK;
      const int n = i / H_QK;
      q[n * head_stride + h] = (biases == nullptr) ? query[i] : (query[i] + biases[i]);
    }
  } else {
    // Padding slot: there is no query data, so the output is the bias alone.
    for (int i = threadIdx.x; i < D_QK; i += blockDim.x) {
      const int h = i % H_QK;
      const int n = i / H_QK;
      q[n * head_stride + h] = (biases == nullptr) ? T{} : biases[i];
    }
  }
}
// Launches TransposeQKV_TNH_3BNSH on `stream` with a (sequence_length, batch_size) grid and
// kMAX_THREADS_PER_BLOCK threads per block.
//
//   query / bias / output: forwarded to the kernel unchanged (element type T may be a vector type,
//                          in which case qk_head_size is expressed in units of T).
//   token_offset:          one entry per (batch, sequence) slot, read by the kernel.
//   token_count:           number of valid rows in the packed query.
//
// The launch is asynchronous; this function does not synchronize the stream.
template <typename T>
void InvokeTranspose(
    const T* query, const T* bias, T* output,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const int32_t* token_offset, int32_t token_count, cudaStream_t stream) {
  // A grid dimension of zero is an invalid launch configuration, and an empty batch or
  // sequence has no output to write, so skip the launch entirely.
  if (batch_size <= 0 || sequence_length <= 0) {
    return;
  }

  const dim3 grid(sequence_length, batch_size);
  TransposeQKV_TNH_3BNSH<T><<<grid, kMAX_THREADS_PER_BLOCK, 0, stream>>>(
      query,
      bias,
      num_heads,
      qk_head_size,
      output,
      token_offset,
      token_count);
}
// Type trait: maps a scalar element type to the 4-element vector type used when
// RestorePaddingAddBiasTranspose processes four elements per access (head size divisible by 4).
// Only the specializations below are defined; any other T fails to compile.
template <typename T>
struct T4;
// float -> float4
template <>
struct T4<float> {
using Type = float4;
};
// half -> Half4 (Half4 is declared outside this view; presumably four packed halves - confirm).
template <>
struct T4<half> {
using Type = Half4;
};
// Type trait: maps a scalar element type to the 2-element vector type used when
// RestorePaddingAddBiasTranspose processes two elements per access (head size even but not
// divisible by 4). Only the specializations below are defined.
template <typename T>
struct T2;
// float -> float2
template <>
struct T2<float> {
using Type = float2;
};
// half -> half2
template <>
struct T2<half> {
using Type = half2;
};
// Returns true when `ptr` satisfies the alignment of VecT and may therefore be reinterpreted
// as a pointer to VecT. A null pointer counts as aligned.
template <typename VecT>
static inline bool IsAlignedFor(const void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % alignof(VecT) == 0;
}

// Restores padding for a packed query, adds bias, and transposes the result to BNSH.
//
//   query:           packed input, token_count rows of (num_heads x qk_head_size) elements
//   bias:            (num_heads x qk_head_size) elements; may be null (no bias is added)
//   output:          (batch_size, num_heads, sequence_length, qk_head_size)
//   token_offset:    (batch_size x sequence_length) entries consumed by the kernel
//   token_count:     number of valid rows in query
//   stream:          CUDA stream used for the (asynchronous) kernel launch
//
// The widest element type that is legal for the given head size is used:
// 4-wide when qk_head_size is a multiple of 4, 2-wide when it is even, scalar otherwise.
// A vectorized path is taken only when query, bias, and output are all aligned for the
// vector type; otherwise the next narrower path is used, because reinterpreting a
// misaligned pointer as a vector type causes a misaligned-address fault on the device.
template <typename T>
void RestorePaddingAddBiasTranspose(
    const T* query, const T* bias, T* output,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const int32_t* token_offset, int32_t token_count, cudaStream_t stream) {
  using T4Type = typename T4<T>::Type;
  using T2Type = typename T2<T>::Type;

  const bool can_use_vec4 = (0 == (qk_head_size & 3)) &&
                            IsAlignedFor<T4Type>(query) &&
                            IsAlignedFor<T4Type>(bias) &&
                            IsAlignedFor<T4Type>(output);
  if (can_use_vec4) {
    // Head size is expressed in units of the vector type.
    InvokeTranspose<T4Type>(
        reinterpret_cast<const T4Type*>(query),
        reinterpret_cast<const T4Type*>(bias),
        reinterpret_cast<T4Type*>(output),
        batch_size, sequence_length,
        num_heads, qk_head_size / 4,
        token_offset, token_count, stream);
    return;
  }

  const bool can_use_vec2 = (0 == (qk_head_size & 1)) &&
                            IsAlignedFor<T2Type>(query) &&
                            IsAlignedFor<T2Type>(bias) &&
                            IsAlignedFor<T2Type>(output);
  if (can_use_vec2) {
    InvokeTranspose<T2Type>(
        reinterpret_cast<const T2Type*>(query),
        reinterpret_cast<const T2Type*>(bias),
        reinterpret_cast<T2Type*>(output),
        batch_size, sequence_length,
        num_heads, qk_head_size / 2,
        token_offset, token_count, stream);
    return;
  }

  // Scalar fallback: odd head size or insufficiently aligned buffers.
  InvokeTranspose<T>(
      query, bias, output,
      batch_size, sequence_length,
      num_heads, qk_head_size,
      token_offset, token_count, stream);
}
// Explicit instantiations of RestorePaddingAddBiasTranspose for the two supported element
// types, so that code which only sees the template declaration can link against them.
// float instantiation.
template void RestorePaddingAddBiasTranspose(
const float* query, const float* bias, float* output,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const int32_t* token_offset, int32_t token_count, cudaStream_t stream);
// half instantiation.
template void RestorePaddingAddBiasTranspose(
const half* query, const half* bias, half* output,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const int32_t* token_offset, int32_t token_count, cudaStream_t stream);
template <typename T>
__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
const T* relative_attention_bias_table,

View file

@ -36,6 +36,12 @@ Status LaunchGatedRelativePositionBiasKernel(
const int D,
const int ldqw);
// Restores padding for a packed query, adds bias, and transposes the result.
//   query:        packed input with token_count rows of (num_heads x qk_head_size) elements
//   bias:         (num_heads x qk_head_size) elements; a null pointer means no bias is added
//   output:       (batch_size, num_heads, sequence_length, qk_head_size)
//   token_offset: (batch_size, sequence_length) int32 offsets, one per padded slot
//   token_count:  number of rows in the packed query
//   stream:       CUDA stream the kernel is launched on
// Explicit instantiations are provided for float and half.
template <typename T>
void RestorePaddingAddBiasTranspose(
const T* query, const T* bias, T* output,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const int32_t* token_offset, int32_t token_count, cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -1322,25 +1322,45 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(GatedRelativePositionBias_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T")
.Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)", "T")
.Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T")
.Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T")
.Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T")
.Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T")
.Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T")
.Input(6, "token_offset", "offset of each token with shape (batch_size, seq_len)", "M", OpSchema::Optional)
.Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain token_offset to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
int64_t num_heads = getAttribute(ctx, "num_heads", -1L);
if (hasInputShape(ctx, 0)) {
auto& query_layer_shape = getInputShape(ctx, 0);
// When padding is removed:
// query_layer: (token_count, num_heads x head_size)
// token_offset: (batch_size, seq_len)
// Otherwise:
// query_layer: (batch_size, seq_len, num_heads x head_size)
// token_offset: None
// Output shape: (batch_size, num_heads, seq_len, seq_len)
if (hasInputShape(ctx, 6)) {
auto& token_offset_shape = getInputShape(ctx, 6);
TensorShapeProto output_shape;
*output_shape.add_dim() = query_layer_shape.dim(0);
*output_shape.add_dim() = token_offset_shape.dim(0);
output_shape.add_dim()->set_dim_value(num_heads);
*output_shape.add_dim() = query_layer_shape.dim(1);
*output_shape.add_dim() = query_layer_shape.dim(1);
*output_shape.add_dim() = token_offset_shape.dim(1);
*output_shape.add_dim() = token_offset_shape.dim(1);
updateOutputShape(ctx, 0, output_shape);
} else if (hasInputShape(ctx, 0)) {
auto& query_layer_shape = getInputShape(ctx, 0);
if (query_layer_shape.dim().size() == 3) {
TensorShapeProto output_shape;
*output_shape.add_dim() = query_layer_shape.dim(0);
output_shape.add_dim()->set_dim_value(num_heads);
*output_shape.add_dim() = query_layer_shape.dim(1);
*output_shape.add_dim() = query_layer_shape.dim(1);
updateOutputShape(ctx, 0, output_shape);
}
}
}));

View file

@ -227,7 +227,9 @@ static void RunGatedRelativePositionBiasTest(
int num_heads,
int head_size,
int D,
bool use_float16 = false) {
bool use_float16,
const std::vector<int>& token_offset,
int token_count) {
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
@ -243,21 +245,32 @@ static void RunGatedRelativePositionBiasTest(
std::vector<int64_t> eco_a_dims = {1, num_heads, 1, 1};
std::vector<int64_t> output_dims = {batch_size, num_heads, seq_len, seq_len};
std::vector<int64_t> packed_query_dims = {token_count, num_heads * head_size};
std::vector<int64_t> token_offset_dims = {batch_size, seq_len};
bool is_padding_removed = token_offset.size() > 0;
if (use_float16) {
tester.AddInput<MLFloat16>("query_layer", query_layer_dims, ToFloat16(query_layer));
tester.AddInput<MLFloat16>("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, ToFloat16(query_layer));
tester.AddInput<MLFloat16>("query_bias", query_bias_dims, ToFloat16(query_bias));
tester.AddInput<MLFloat16>("rel_pos", rel_pos_dims, ToFloat16(rel_pos));
tester.AddInput<MLFloat16>("weight", weight_dims, ToFloat16(weight));
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias));
tester.AddInput<MLFloat16>("eco_a", eco_a_dims, ToFloat16(eco_a));
if (is_padding_removed) {
tester.AddInput<int>("token_offset", token_offset_dims, token_offset);
}
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output));
} else {
tester.AddInput<float>("query_layer", query_layer_dims, query_layer);
tester.AddInput<float>("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, query_layer);
tester.AddInput<float>("query_bias", query_bias_dims, query_bias);
tester.AddInput<float>("rel_pos", rel_pos_dims, rel_pos);
tester.AddInput<float>("weight", weight_dims, weight);
tester.AddInput<float>("bias", bias_dims, bias);
tester.AddInput<float>("eco_a", eco_a_dims, eco_a);
if (is_padding_removed) {
tester.AddInput<int>("token_offset", token_offset_dims, token_offset);
}
tester.AddOutput<float>("output", output_dims, output);
}
@ -304,8 +317,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) {
0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f,
0.48772895f, 0.84479123f};
const std::vector<int> token_offset;
int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
batch_size, seq_len, num_heads, head_size, D, true);
batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count);
}
TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) {
@ -350,8 +365,10 @@ TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) {
0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f,
0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f};
const std::vector<int> token_offset;
int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
batch_size, seq_len, num_heads, head_size, D, false);
batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count);
}
TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) {
@ -410,17 +427,19 @@ TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) {
0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f,
1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f};
const std::vector<int> token_offset;
int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
batch_size, seq_len, num_heads, head_size, D, false);
batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count);
}
TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) {
TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8_NoPadding) {
constexpr int batch_size = 2;
constexpr int num_heads = 2;
constexpr int seq_len = 8;
constexpr int head_size = 4;
constexpr int D = 8;
const std::vector<int64_t> query_layer_dim = {2, 8, 8};
const std::vector<int64_t> query_layer_dim = {16, 8};
const std::vector<float> query_layer = {
0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f,
0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f,
@ -506,8 +525,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) {
0.57226765f, 0.46851522f, 0.26718724f, 0.6390965f, 1.0312729f, 0.39947683f, 0.22935463f, 0.35571814f,
0.005002509f, 0.82025534f, 0.29372898f, 0.18800265f, 0.2395663f, 0.8900865f, 0.8644386f, 0.998915f};
const std::vector<int> token_offset = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
int token_count = 16;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
batch_size, seq_len, num_heads, head_size, D, true);
batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count);
}
} // namespace test