From 50bf310dea83dca2001e70671077bc5efd3a999f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 1 Aug 2023 16:39:09 -0700 Subject: [PATCH] [CUDA] RelativePositionBias supports input with padding removed (#16923) update RelativePositionBias to support input with padding removed. - [x] add bias transpose kernel - [x] add test - [x] update operator document --- docs/ContribOperators.md | 8 +- docs/OperatorKernels.md | 2 +- .../cuda/bert/relative_attn_bias.cc | 64 +++++--- .../cuda/bert/relative_attn_bias_impl.cu | 140 ++++++++++++++++++ .../cuda/bert/relative_attn_bias_impl.h | 6 + .../core/graph/contrib_ops/bert_defs.cc | 32 +++- .../relative_attention_bias_test.cc | 39 +++-- 7 files changed, 254 insertions(+), 37 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 0706e0f527..77b84e2b3f 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1886,11 +1886,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs +#### Inputs (6 - 7)
query_layer : T
-
tensor with shape (batch_size, seq_len, num_heads x head_size)
+
tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)
query_bias : T
1-d tensor with shape (num_heads x head_size)
rel_pos : T
@@ -1901,6 +1901,8 @@ This version of the operator has been available since version 1 of the 'com.micr
bias for the gated_ur_linear, shape (D)
eco_a : T
tensor of shape (1, num_heads, 1, 1)
+
token_offset (optional) : M
+
offset of each token with shape (batch_size, seq_len)
#### Outputs @@ -1915,6 +1917,8 @@ This version of the operator has been available since version 1 of the 'com.micr
T : tensor(float), tensor(float16)
Constrain input and output types to float tensors.
+
M : tensor(int32)
+
Constrain token_offset to integer types
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bb57a8f817..1531d772c0 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -831,7 +831,7 @@ Do not modify directly.* |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index ceae4d511a..0284e6f2aa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -100,17 +100,33 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c const Tensor& weight_tensor = *context->Input(3); const Tensor& bias_tensor = *context->Input(4); const Tensor& eco_a_tensor = *context->Input(5); + const Tensor* token_offset_tensor = context->Input(6); const auto& query_dims = query_tensor.Shape().GetDims(); - ORT_ENFORCE(query_dims.size() == 3); - ORT_ENFORCE(query_dims[2] > 0); - ORT_ENFORCE(query_dims[2] % num_heads_ == 0); - const auto batch_size = SafeInt(query_dims[0]); - const auto seq_len = SafeInt(query_dims[1]); - const auto head_size = SafeInt(query_dims[2] / num_heads_); + + bool is_padding_removed = (token_offset_tensor != nullptr); + ORT_ENFORCE(query_dims.size() == (is_padding_removed ? 2 : 3)); + int hidden_size = static_cast(query_dims[query_dims.size() - 1]); + + ORT_ENFORCE(hidden_size > 0); + ORT_ENFORCE(hidden_size % num_heads_ == 0); + const auto head_size = SafeInt(hidden_size / num_heads_); + + int batch_size; + int seq_len; + int token_count = 0; + if (is_padding_removed) { + const auto& token_offset_dims = token_offset_tensor->Shape().GetDims(); + batch_size = SafeInt(token_offset_dims[0]); + seq_len = SafeInt(token_offset_dims[1]); + token_count = SafeInt(query_dims[0]); + } else { + batch_size = SafeInt(query_dims[0]); + seq_len = SafeInt(query_dims[1]); + } ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1); - ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]); + ORT_ENFORCE(query_bias_tensor.Shape()[0] == hidden_size); const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims(); ORT_ENFORCE(rel_pos_dims.size() == 4); @@ -149,21 +165,31 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm)); auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); - // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) - constexpr int format = 1; - constexpr int total_maxtrix = 1; - constexpr int num_matrix_to_transpose = 1; - LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, - batch_size, seq_len, num_heads_, head_size, - reinterpret_cast(query_tensor.template Data()), - reinterpret_cast(query_bias_tensor.template Data()), - reinterpret_cast(workspace.get()), - false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); + cudaStream_t stream = Stream(context); + if (!is_padding_removed) { + // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) + constexpr int format = 1; + constexpr int total_maxtrix = 1; + constexpr int num_matrix_to_transpose = 1; + LaunchAddBiasTranspose(stream, num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, + batch_size, seq_len, num_heads_, head_size, + reinterpret_cast(query_tensor.Data()), + reinterpret_cast(query_bias_tensor.Data()), + reinterpret_cast(workspace.get()), + false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); + } else { + RestorePaddingAddBiasTranspose(reinterpret_cast(query_tensor.Data()), + reinterpret_cast(query_bias_tensor.Data()), + reinterpret_cast(workspace.get()), + batch_size, seq_len, num_heads_, head_size, + token_offset_tensor->Data(), + token_count, stream); + } // reuse output if possible CudaT* gemm_output = reuse_output ? reinterpret_cast(output->template MutableData()) : (reinterpret_cast(workspace.get()) + elements_in_query); - int ld_gemm_output = reuse_output ? seq_len : D; + int ld_gemm_output = reuse_output ? seq_len : static_cast(D); const CudaT one = ToCudaType::FromFloat(1.0f); const CudaT zero = ToCudaType::FromFloat(0.0f); @@ -177,7 +203,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c &zero, gemm_output, ld_gemm_output, device_prop)); auto status = LaunchGatedRelativePositionBiasKernel( - device_prop, Stream(context), + device_prop, stream, reinterpret_cast(output->template MutableData()), reinterpret_cast(rel_pos_tensor.template Data()), reinterpret_cast(gemm_output), diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index ebe87158d1..ad8acd9211 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -29,6 +29,146 @@ namespace cuda { using namespace onnxruntime::cuda; +static constexpr int32_t kMAX_THREADS_PER_BLOCK = 256; + +// Grid: (S, B) +// Block: 256 +// For packed input +// query: TxNxH +// Output: BxNxSxH +// Where: +// T is token_count +// B is batch_size +// S is sequence_length +// N is num_heads +// H is head_size +template +__global__ void TransposeQKV_TNH_3BNSH( + const T* query, + const T* biases, + int32_t N, + int32_t H_QK, + T* q, + const int32_t* token_offset, + int32_t token_count) { + int s = blockIdx.x; + int b = blockIdx.y; + + int S = gridDim.x; + + const int packing_token_idx = b * S + s; + const int padding_token_idx = token_offset[packing_token_idx]; + b = padding_token_idx / S; + s = padding_token_idx % S; + + const int D_QK = N * H_QK; + query += packing_token_idx * D_QK; + + q += (b * N * S + s) * H_QK; + + if (packing_token_idx < token_count) { + for (int i = threadIdx.x; i < D_QK; i += blockDim.x) { + int h = i % H_QK; + int n = i / H_QK; + q[n * S * H_QK + h] = (biases == nullptr) ? query[i] : (query[i] + biases[i]); + } + } else { + for (int i = threadIdx.x; i < D_QK; i += blockDim.x) { + int h = i % H_QK; + int n = i / H_QK; + q[n * S * H_QK + h] = (biases == nullptr) ? T{} : biases[i]; + } + } +} + +template +void InvokeTranspose( + const T* query, const T* bias, T* output, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { + const dim3 grid(sequence_length, batch_size); + TransposeQKV_TNH_3BNSH<<>>( + query, + bias, + num_heads, + qk_head_size, + output, + token_offset, + token_count); +} + +template +struct T4; + +template <> +struct T4 { + using Type = float4; +}; + +template <> +struct T4 { + using Type = Half4; +}; + +template +struct T2; + +template <> +struct T2 { + using Type = float2; +}; + +template <> +struct T2 { + using Type = half2; +}; + +template +void RestorePaddingAddBiasTranspose( + const T* query, const T* bias, T* output, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { + if (0 == (qk_head_size & 3)) { + using T4Type = typename T4::Type; + const int H = qk_head_size / 4; + const T4Type* query2 = reinterpret_cast(query); + const T4Type* bias2 = reinterpret_cast(bias); + T4Type* output2 = reinterpret_cast(output); + InvokeTranspose( + query2, bias2, output2, + batch_size, sequence_length, + num_heads, H, + token_offset, token_count, stream); + } else if (0 == (qk_head_size & 1)) { + using T2Type = typename T2::Type; + const int H = qk_head_size / 2; + const T2Type* query2 = reinterpret_cast(query); + const T2Type* bias2 = reinterpret_cast(bias); + T2Type* output2 = reinterpret_cast(output); + InvokeTranspose( + query2, bias2, output2, + batch_size, sequence_length, + num_heads, H, + token_offset, token_count, stream); + } else { + InvokeTranspose( + query, bias, output, + batch_size, sequence_length, + num_heads, qk_head_size, + token_offset, token_count, stream); + } +} + +template void RestorePaddingAddBiasTranspose( + const float* query, const float* bias, float* output, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const int32_t* token_offset, int32_t token_count, cudaStream_t stream); + +template void RestorePaddingAddBiasTranspose( + const half* query, const half* bias, half* output, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const int32_t* token_offset, int32_t token_count, cudaStream_t stream); + template __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const T* relative_attention_bias_table, diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h index 572db0c023..74edc49dff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -36,6 +36,12 @@ Status LaunchGatedRelativePositionBiasKernel( const int D, const int ldqw); +template +void RestorePaddingAddBiasTranspose( + const T* query, const T* bias, T* output, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const int32_t* token_offset, int32_t token_count, cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 6bce65e8c8..d2ff7e5351 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1322,25 +1322,45 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc(GatedRelativePositionBias_ver1_doc) .Attr("num_heads", "Number of attention heads", AttributeProto::INT) - .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") + .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)", "T") .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") + .Input(6, "token_offset", "offset of each token with shape (batch_size, seq_len)", "M", OpSchema::Optional) .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain token_offset to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t num_heads = getAttribute(ctx, "num_heads", -1L); - if (hasInputShape(ctx, 0)) { - auto& query_layer_shape = getInputShape(ctx, 0); + + // When padding is removed: + // query_layer: (token_count, num_heads x head_size) + // token_offset: (batch_size, seq_len) + // Otherwise: + // query_layer: (batch_size, seq_len, num_heads x head_size) + // token_offset: None + // Output shape: (batch_size, num_heads, seq_len, seq_len) + if (hasInputShape(ctx, 6)) { + auto& token_offset_shape = getInputShape(ctx, 6); TensorShapeProto output_shape; - *output_shape.add_dim() = query_layer_shape.dim(0); + *output_shape.add_dim() = token_offset_shape.dim(0); output_shape.add_dim()->set_dim_value(num_heads); - *output_shape.add_dim() = query_layer_shape.dim(1); - *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = token_offset_shape.dim(1); + *output_shape.add_dim() = token_offset_shape.dim(1); updateOutputShape(ctx, 0, output_shape); + } else if (hasInputShape(ctx, 0)) { + auto& query_layer_shape = getInputShape(ctx, 0); + if (query_layer_shape.dim().size() == 3) { + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + } } })); diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 5c08386762..6885a460d7 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -227,7 +227,9 @@ static void RunGatedRelativePositionBiasTest( int num_heads, int head_size, int D, - bool use_float16 = false) { + bool use_float16, + const std::vector& token_offset, + int token_count) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); @@ -243,21 +245,32 @@ static void RunGatedRelativePositionBiasTest( std::vector eco_a_dims = {1, num_heads, 1, 1}; std::vector output_dims = {batch_size, num_heads, seq_len, seq_len}; + std::vector packed_query_dims = {token_count, num_heads * head_size}; + std::vector token_offset_dims = {batch_size, seq_len}; + bool is_padding_removed = token_offset.size() > 0; + if (use_float16) { - tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer)); + tester.AddInput("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, ToFloat16(query_layer)); tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias)); tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos)); tester.AddInput("weight", weight_dims, ToFloat16(weight)); tester.AddInput("bias", bias_dims, ToFloat16(bias)); tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a)); + if (is_padding_removed) { + tester.AddInput("token_offset", token_offset_dims, token_offset); + } tester.AddOutput("output", output_dims, ToFloat16(output)); } else { - tester.AddInput("query_layer", query_layer_dims, query_layer); + tester.AddInput("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, query_layer); tester.AddInput("query_bias", query_bias_dims, query_bias); tester.AddInput("rel_pos", rel_pos_dims, rel_pos); tester.AddInput("weight", weight_dims, weight); tester.AddInput("bias", bias_dims, bias); tester.AddInput("eco_a", eco_a_dims, eco_a); + if (is_padding_removed) { + tester.AddInput("token_offset", token_offset_dims, token_offset); + } + tester.AddOutput("output", output_dims, output); } @@ -304,8 +317,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { 0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f, 0.48772895f, 0.84479123f}; + const std::vector token_offset; + int token_count = 0; RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, - batch_size, seq_len, num_heads, head_size, D, true); + batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count); } TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { @@ -350,8 +365,10 @@ TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { 0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f, 0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f}; + const std::vector token_offset; + int token_count = 0; RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, - batch_size, seq_len, num_heads, head_size, D, false); + batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count); } TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { @@ -410,17 +427,19 @@ TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { 0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f, 1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f}; + const std::vector token_offset; + int token_count = 0; RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, - batch_size, seq_len, num_heads, head_size, D, false); + batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count); } -TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) { +TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8_NoPadding) { constexpr int batch_size = 2; constexpr int num_heads = 2; constexpr int seq_len = 8; constexpr int head_size = 4; constexpr int D = 8; - const std::vector query_layer_dim = {2, 8, 8}; + const std::vector query_layer_dim = {16, 8}; const std::vector query_layer = { 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, @@ -506,8 +525,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) { 0.57226765f, 0.46851522f, 0.26718724f, 0.6390965f, 1.0312729f, 0.39947683f, 0.22935463f, 0.35571814f, 0.005002509f, 0.82025534f, 0.29372898f, 0.18800265f, 0.2395663f, 0.8900865f, 0.8644386f, 0.998915f}; + const std::vector token_offset = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + int token_count = 16; RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, - batch_size, seq_len, num_heads, head_size, D, true); + batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count); } } // namespace test