diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 0706e0f527..77b84e2b3f 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1886,11 +1886,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs
+#### Inputs (6 - 7)
- query_layer : T
-- tensor with shape (batch_size, seq_len, num_heads x head_size)
+- tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)
- query_bias : T
- 1-d tensor with shape (num_heads x head_size)
- rel_pos : T
@@ -1901,6 +1901,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- bias for the gated_ur_linear, shape (D)
- eco_a : T
- tensor of shape (1, num_heads, 1, 1)
+- token_offset (optional) : M
+- offset of each token with shape (batch_size, seq_len)
#### Outputs
@@ -1915,6 +1917,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- T : tensor(float), tensor(float16)
- Constrain input and output types to float tensors.
+- M : tensor(int32)
+- Constrain token_offset to integer types
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index bb57a8f817..1531d772c0 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -831,7 +831,7 @@ Do not modify directly.*
|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
-|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
index ceae4d511a..0284e6f2aa 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
@@ -100,17 +100,33 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c
const Tensor& weight_tensor = *context->Input(3);
const Tensor& bias_tensor = *context->Input(4);
const Tensor& eco_a_tensor = *context->Input(5);
+ const Tensor* token_offset_tensor = context->Input(6);
const auto& query_dims = query_tensor.Shape().GetDims();
- ORT_ENFORCE(query_dims.size() == 3);
- ORT_ENFORCE(query_dims[2] > 0);
- ORT_ENFORCE(query_dims[2] % num_heads_ == 0);
- const auto batch_size = SafeInt(query_dims[0]);
- const auto seq_len = SafeInt(query_dims[1]);
- const auto head_size = SafeInt(query_dims[2] / num_heads_);
+
+ bool is_padding_removed = (token_offset_tensor != nullptr);
+ ORT_ENFORCE(query_dims.size() == (is_padding_removed ? 2 : 3));
+ int hidden_size = static_cast(query_dims[query_dims.size() - 1]);
+
+ ORT_ENFORCE(hidden_size > 0);
+ ORT_ENFORCE(hidden_size % num_heads_ == 0);
+ const auto head_size = SafeInt(hidden_size / num_heads_);
+
+ int batch_size;
+ int seq_len;
+ int token_count = 0;
+ if (is_padding_removed) {
+ const auto& token_offset_dims = token_offset_tensor->Shape().GetDims();
+ batch_size = SafeInt(token_offset_dims[0]);
+ seq_len = SafeInt(token_offset_dims[1]);
+ token_count = SafeInt(query_dims[0]);
+ } else {
+ batch_size = SafeInt(query_dims[0]);
+ seq_len = SafeInt(query_dims[1]);
+ }
ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1);
- ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]);
+ ORT_ENFORCE(query_bias_tensor.Shape()[0] == hidden_size);
const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims();
ORT_ENFORCE(rel_pos_dims.size() == 4);
@@ -149,21 +165,31 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream());
- // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
- constexpr int format = 1;
- constexpr int total_maxtrix = 1;
- constexpr int num_matrix_to_transpose = 1;
- LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
- batch_size, seq_len, num_heads_, head_size,
- reinterpret_cast(query_tensor.template Data()),
- reinterpret_cast(query_bias_tensor.template Data()),
- reinterpret_cast(workspace.get()),
- false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix);
+ cudaStream_t stream = Stream(context);
+ if (!is_padding_removed) {
+ // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
+ constexpr int format = 1;
+ constexpr int total_maxtrix = 1;
+ constexpr int num_matrix_to_transpose = 1;
+ LaunchAddBiasTranspose(stream, num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
+ batch_size, seq_len, num_heads_, head_size,
+ reinterpret_cast(query_tensor.Data()),
+ reinterpret_cast(query_bias_tensor.Data()),
+ reinterpret_cast(workspace.get()),
+ false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix);
+ } else {
+ RestorePaddingAddBiasTranspose(reinterpret_cast(query_tensor.Data()),
+ reinterpret_cast(query_bias_tensor.Data()),
+ reinterpret_cast(workspace.get()),
+ batch_size, seq_len, num_heads_, head_size,
+ token_offset_tensor->Data(),
+ token_count, stream);
+ }
// reuse output if possible
CudaT* gemm_output = reuse_output ? reinterpret_cast(output->template MutableData())
: (reinterpret_cast(workspace.get()) + elements_in_query);
- int ld_gemm_output = reuse_output ? seq_len : D;
+ int ld_gemm_output = reuse_output ? seq_len : static_cast(D);
const CudaT one = ToCudaType::FromFloat(1.0f);
const CudaT zero = ToCudaType::FromFloat(0.0f);
@@ -177,7 +203,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c
&zero, gemm_output, ld_gemm_output, device_prop));
auto status = LaunchGatedRelativePositionBiasKernel(
- device_prop, Stream(context),
+ device_prop, stream,
reinterpret_cast(output->template MutableData()),
reinterpret_cast(rel_pos_tensor.template Data()),
reinterpret_cast(gemm_output),
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
index ebe87158d1..ad8acd9211 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
@@ -29,6 +29,146 @@ namespace cuda {
using namespace onnxruntime::cuda;
+static constexpr int32_t kMAX_THREADS_PER_BLOCK = 256;  // NOTE(review): presumably the thread-block size for the kernel launch below; the launch arguments are not legible in this view
+
+// Scatters a packed (padding-removed) query into the padded BxNxSxH layout, adding the bias on the way.
+// Grid: (S, B), one block per padded slot. Block: 256 threads striding over the N*H elements of a row.
+// Block (s, b) reads token_offset[b * S + s] and writes to the (b', s') slot that value decodes to.
+// query: TxNxH (packed input). Output: BxNxSxH (only the single buffer q is written).
+// Blocks with b * S + s >= token_count have no source row; their slot receives the bias only (or T{}).
+// Where:
+// T is token_count (row count of query, not the element type)
+// B is batch_size
+// S is sequence_length
+// N is num_heads
+// H is head_size (counted in elements of the kernel's element type)
+template
+__global__ void TransposeQKV_TNH_3BNSH(
+ const T* query,  // packed rows, N * H_QK elements each
+ const T* biases,  // N * H_QK elements, or nullptr to skip the bias add
+ int32_t N,
+ int32_t H_QK,
+ T* q,  // destination buffer, indexed as BxNxSxH
+ const int32_t* token_offset,  // read at every b * S + s of the grid, so it must hold gridDim.x * gridDim.y entries
+ int32_t token_count) {
+ int s = blockIdx.x;
+ int b = blockIdx.y;
+
+ int S = gridDim.x;  // sequence length is taken from the launch grid rather than passed as an argument
+
+ const int packing_token_idx = b * S + s;  // linear index of this block within the (B, S) grid
+ const int padding_token_idx = token_offset[packing_token_idx];  // flattened destination slot in the padded layout
+ b = padding_token_idx / S;  // from here on b and s denote the destination batch / position
+ s = padding_token_idx % S;
+
+ const int D_QK = N * H_QK;  // elements per token row across all heads
+ query += packing_token_idx * D_QK;  // only dereferenced in the branch where packing_token_idx < token_count; NOTE(review): 32-bit index math
+
+ q += (b * N * S + s) * H_QK;  // start of (b, head 0, s) in BxNxSxH
+
+ if (packing_token_idx < token_count) {  // a packed source row exists for this block
+ for (int i = threadIdx.x; i < D_QK; i += blockDim.x) {
+ int h = i % H_QK;
+ int n = i / H_QK;
+ q[n * S * H_QK + h] = (biases == nullptr) ? query[i] : (query[i] + biases[i]);  // n * S * H_QK steps to head n
+ }
+ } else {  // padding slot: nothing to copy from query
+ for (int i = threadIdx.x; i < D_QK; i += blockDim.x) {
+ int h = i % H_QK;
+ int n = i / H_QK;
+ q[n * S * H_QK + h] = (biases == nullptr) ? T{} : biases[i];  // bias alone, or a value-initialized T
+ }
+ }
+}
+
+template
+void InvokeTranspose(  // Host-side launcher for the TransposeQKV_TNH_3BNSH kernel.
+ const T* query, const T* bias, T* output,
+ const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,  // qk_head_size is counted in elements of T
+ const int32_t* token_offset, int32_t token_count, cudaStream_t stream) {
+ const dim3 grid(sequence_length, batch_size);  // x = sequence position, y = batch index, matching blockIdx.x / blockIdx.y in the kernel
+ TransposeQKV_TNH_3BNSH<<>>(  // NOTE(review): launch configuration is not legible in this view; the launch result is not checked here
+ query,
+ bias,
+ num_heads,  // kernel parameter N
+ qk_head_size,  // kernel parameter H_QK
+ output,  // kernel parameter q
+ token_offset,
+ token_count);
+}
+
+template
+struct T4;  // Type trait: T4::Type names the vector type used to move four elements per access
+
+template <>
+struct T4 {  // NOTE(review): presumably the float specialization; the template argument is not legible in this view
+ using Type = float4;
+};
+
+template <>
+struct T4 {  // NOTE(review): presumably the half specialization; Half4 is declared outside this view
+ using Type = Half4;
+};
+
+template
+struct T2;  // Type trait: T2::Type names the vector type used to move two elements per access
+
+template <>
+struct T2 {  // NOTE(review): presumably the float specialization; the template argument is not legible in this view
+ using Type = float2;
+};
+
+template <>
+struct T2 {  // NOTE(review): presumably the half specialization; the template argument is not legible in this view
+ using Type = half2;
+};
+
+template
+void RestorePaddingAddBiasTranspose(  // Picks the widest element grouping that divides qk_head_size and forwards to InvokeTranspose.
+ const T* query, const T* bias, T* output,
+ const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+ const int32_t* token_offset, int32_t token_count, cudaStream_t stream) {
+ if (0 == (qk_head_size & 3)) {  // head size is a multiple of 4: handle four elements per access
+ using T4Type = typename T4::Type;
+ const int H = qk_head_size / 4;  // head size re-expressed in units of the 4-wide type
+ const T4Type* query2 = reinterpret_cast(query);  // NOTE(review): assumes the buffers are aligned for the wider type — confirm with callers
+ const T4Type* bias2 = reinterpret_cast(bias);
+ T4Type* output2 = reinterpret_cast(output);
+ InvokeTranspose(
+ query2, bias2, output2,
+ batch_size, sequence_length,
+ num_heads, H,
+ token_offset, token_count, stream);
+ } else if (0 == (qk_head_size & 1)) {  // head size is even but not a multiple of 4: two elements per access
+ using T2Type = typename T2::Type;
+ const int H = qk_head_size / 2;  // head size re-expressed in units of the 2-wide type
+ const T2Type* query2 = reinterpret_cast(query);  // NOTE(review): same alignment assumption as the 4-wide path
+ const T2Type* bias2 = reinterpret_cast(bias);
+ T2Type* output2 = reinterpret_cast(output);
+ InvokeTranspose(
+ query2, bias2, output2,
+ batch_size, sequence_length,
+ num_heads, H,
+ token_offset, token_count, stream);
+ } else {  // odd head size: one element per access, pointers passed through unchanged
+ InvokeTranspose(
+ query, bias, output,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ token_offset, token_count, stream);
+ }
+}
+
+template void RestorePaddingAddBiasTranspose(  // explicit instantiation for float buffers, so the definition can stay in this translation unit
+ const float* query, const float* bias, float* output,
+ const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+ const int32_t* token_offset, int32_t token_count, cudaStream_t stream);
+
+template void RestorePaddingAddBiasTranspose(  // explicit instantiation for half buffers
+ const half* query, const half* bias, half* output,
+ const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+ const int32_t* token_offset, int32_t token_count, cudaStream_t stream);
+
template
__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
const T* relative_attention_bias_table,
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
index 572db0c023..74edc49dff 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
@@ -36,6 +36,12 @@ Status LaunchGatedRelativePositionBiasKernel(
const int D,
const int ldqw);
+template
+void RestorePaddingAddBiasTranspose(  // Writes a packed (token_count x num_heads*qk_head_size) query into a padded BxNxSxH output, adding bias.
+ const T* query, const T* bias, T* output,  // bias may be nullptr, in which case no bias is added
+ const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+ const int32_t* token_offset, int32_t token_count, cudaStream_t stream);  // token_offset is read for every (batch, position) slot; work is enqueued on stream
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 6bce65e8c8..d2ff7e5351 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1322,25 +1322,45 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(GatedRelativePositionBias_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
- .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T")
+ .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size) or (token_count, num_heads x head_size)", "T")
.Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T")
.Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T")
.Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T")
.Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T")
.Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T")
+ .Input(6, "token_offset", "offset of each token with shape (batch_size, seq_len)", "M", OpSchema::Optional)
.Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
+ .TypeConstraint("M", {"tensor(int32)"}, "Constrain token_offset to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
int64_t num_heads = getAttribute(ctx, "num_heads", -1L);
- if (hasInputShape(ctx, 0)) {
- auto& query_layer_shape = getInputShape(ctx, 0);
+
+ // When padding is removed:
+ // query_layer: (token_count, num_heads x head_size)
+ // token_offset: (batch_size, seq_len)
+ // Otherwise:
+ // query_layer: (batch_size, seq_len, num_heads x head_size)
+ // token_offset: None
+ // Output shape: (batch_size, num_heads, seq_len, seq_len)
+ if (hasInputShape(ctx, 6)) {
+ auto& token_offset_shape = getInputShape(ctx, 6);
TensorShapeProto output_shape;
- *output_shape.add_dim() = query_layer_shape.dim(0);
+ *output_shape.add_dim() = token_offset_shape.dim(0);
output_shape.add_dim()->set_dim_value(num_heads);
- *output_shape.add_dim() = query_layer_shape.dim(1);
- *output_shape.add_dim() = query_layer_shape.dim(1);
+ *output_shape.add_dim() = token_offset_shape.dim(1);
+ *output_shape.add_dim() = token_offset_shape.dim(1);
updateOutputShape(ctx, 0, output_shape);
+ } else if (hasInputShape(ctx, 0)) {
+ auto& query_layer_shape = getInputShape(ctx, 0);
+ if (query_layer_shape.dim().size() == 3) {
+ TensorShapeProto output_shape;
+ *output_shape.add_dim() = query_layer_shape.dim(0);
+ output_shape.add_dim()->set_dim_value(num_heads);
+ *output_shape.add_dim() = query_layer_shape.dim(1);
+ *output_shape.add_dim() = query_layer_shape.dim(1);
+ updateOutputShape(ctx, 0, output_shape);
+ }
}
}));
diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
index 5c08386762..6885a460d7 100644
--- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
+++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
@@ -227,7 +227,9 @@ static void RunGatedRelativePositionBiasTest(
int num_heads,
int head_size,
int D,
- bool use_float16 = false) {
+ bool use_float16,
+ const std::vector& token_offset,
+ int token_count) {
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
@@ -243,21 +245,32 @@ static void RunGatedRelativePositionBiasTest(
std::vector eco_a_dims = {1, num_heads, 1, 1};
std::vector output_dims = {batch_size, num_heads, seq_len, seq_len};
+ std::vector packed_query_dims = {token_count, num_heads * head_size};
+ std::vector token_offset_dims = {batch_size, seq_len};
+ bool is_padding_removed = token_offset.size() > 0;
+
if (use_float16) {
- tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer));
+ tester.AddInput("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, ToFloat16(query_layer));
tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias));
tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos));
tester.AddInput("weight", weight_dims, ToFloat16(weight));
tester.AddInput("bias", bias_dims, ToFloat16(bias));
tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a));
+ if (is_padding_removed) {
+ tester.AddInput("token_offset", token_offset_dims, token_offset);
+ }
tester.AddOutput("output", output_dims, ToFloat16(output));
} else {
- tester.AddInput("query_layer", query_layer_dims, query_layer);
+ tester.AddInput("query_layer", is_padding_removed ? packed_query_dims : query_layer_dims, query_layer);
tester.AddInput("query_bias", query_bias_dims, query_bias);
tester.AddInput("rel_pos", rel_pos_dims, rel_pos);
tester.AddInput("weight", weight_dims, weight);
tester.AddInput("bias", bias_dims, bias);
tester.AddInput("eco_a", eco_a_dims, eco_a);
+ if (is_padding_removed) {
+ tester.AddInput("token_offset", token_offset_dims, token_offset);
+ }
+
tester.AddOutput("output", output_dims, output);
}
@@ -304,8 +317,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) {
0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f,
0.48772895f, 0.84479123f};
+ const std::vector token_offset;
+ int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
- batch_size, seq_len, num_heads, head_size, D, true);
+ batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count);
}
TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) {
@@ -350,8 +365,10 @@ TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) {
0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f,
0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f};
+ const std::vector token_offset;
+ int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
- batch_size, seq_len, num_heads, head_size, D, false);
+ batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count);
}
TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) {
@@ -410,17 +427,19 @@ TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) {
0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f,
1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f};
+ const std::vector token_offset;
+ int token_count = 0;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
- batch_size, seq_len, num_heads, head_size, D, false);
+ batch_size, seq_len, num_heads, head_size, D, false, token_offset, token_count);
}
-TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) {
+TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8_NoPadding) {
constexpr int batch_size = 2;
constexpr int num_heads = 2;
constexpr int seq_len = 8;
constexpr int head_size = 4;
constexpr int D = 8;
- const std::vector query_layer_dim = {2, 8, 8};
+ const std::vector query_layer_dim = {16, 8};
const std::vector query_layer = {
0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f,
0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f,
@@ -506,8 +525,10 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_2x8x2x4x8) {
0.57226765f, 0.46851522f, 0.26718724f, 0.6390965f, 1.0312729f, 0.39947683f, 0.22935463f, 0.35571814f,
0.005002509f, 0.82025534f, 0.29372898f, 0.18800265f, 0.2395663f, 0.8900865f, 0.8644386f, 0.998915f};
+ const std::vector token_offset = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ int token_count = 16;
RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
- batch_size, seq_len, num_heads, head_size, D, true);
+ batch_size, seq_len, num_heads, head_size, D, true, token_offset, token_count);
}
} // namespace test