Fix GroupQueryAttention benchmark script (#20291)

### Description

Fix a few issues in GQA:
(1) memory efficient attention does not have bfloat16, need disable it
when bfloat16 is used.
(2) When prompt length is 1, it is not classified as prompt.
(3) Fix benchmark_gqa.py
(4) Add a comment about seqlen_k to avoid confusion.

### Motivation and Context
https://github.com/microsoft/onnxruntime/pull/20279
This commit is contained in:
Tianlei Wu 2024-05-08 09:48:46 -07:00 committed by GitHub
parent b6d9abf150
commit 1f509215bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 362 additions and 282 deletions

View file

@ -45,7 +45,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
is_past_bsnh_ = false;
is_unidirectional_ = true;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
@ -59,7 +60,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#endif
#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
// Memory efficient attention only supports float and float16, not bfloat16.
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
@ -160,7 +162,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {

View file

@ -32,8 +32,6 @@ Status CheckInputs(const Tensor* query,
// query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
// key (K) : (B, S, D_kv) or nullptr
// value (V) : (B, S, D_kv) or nullptr
ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
const bool is_packed_qkv = key == nullptr;
@ -241,7 +239,11 @@ Status CheckInputs(const Tensor* query,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}
bool is_prompt = sequence_length != 1;
bool is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
if (parameters != nullptr) {
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
@ -256,7 +258,6 @@ Status CheckInputs(const Tensor* query,
output_parameters->kv_num_heads = kv_num_heads;
output_parameters->rotary_dim = rotary_dim;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
output_parameters->qkv_format = qkv_format;

View file

@ -632,7 +632,7 @@ Status FlashAttention(
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
bool is_causal = true;
bool is_causal = parameters.is_unidirectional;
bool is_bf16 = std::is_same<T, BFloat16>::value;
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));

View file

@ -1104,6 +1104,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(5,
"seqlens_k",
// For prompt, the value is number of tokens (excluding padding) - 1.
"1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.",
"M")
.Input(6,

View file

@ -1,160 +1,207 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux:
sh benchmark_mha.sh
"""
import math
import random
import statistics
import time
from typing import Optional
import torch
from onnx import TensorProto, helper
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
class InputFormats:
QKV_BSNH = 0
QKV_BNSH = 1
class AttentionConfig:
def __init__(
self,
operator: str,
batch_size: int,
sequence_length: int,
max_sequence_length: int,
past_sequence_length: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
softmax_scale: Optional[float],
do_rotary: bool,
rotary_interleaved: bool,
device="cuda",
dtype=torch.float16,
share_buffer: bool = True,
is_packed_qkv: bool = False,
):
self.operator = operator
self.batch_size = batch_size
self.sequence_length = sequence_length
self.max_sequence_length = max_sequence_length
self.past_sequence_length = past_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
# Derived values
self.total_sequence_length = sequence_length + past_sequence_length
self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
self.do_rotary = do_rotary
self.rotary_interleaved = rotary_interleaved
self.device = device
self.share_buffer = share_buffer
self.is_packed_qkv = is_packed_qkv
self.dtype = dtype
def shape_dict(self):
return {
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"total_sequence_length": (1,),
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
}
def get_cos_sin_cache(self, dtype):
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos.to(device=self.device), sin.to(device=self.device)
def random_inputs(self):
device = self.device
# bfloat16 is not supported in ORT python I/O binding API
dtype = torch.float16
shape_dict = self.shape_dict()
torch.manual_seed(123)
feeds = {
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
}
if self.do_rotary:
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
feeds["cos_cache"] = cos_cache
feeds["sin_cache"] = sin_cache
return feeds
class Config:
batch_size = 0
sequence_length = 0
kv_sequence_length = 0
past_sequence_length = 0
num_heads = 0
kv_num_heads = 0
head_size = 0
class GroupQueryAttentionConfig(AttentionConfig):
def __init__(
self,
batch_size: int,
sequence_length: int,
max_sequence_length: int,
past_sequence_length: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
softmax_scale=None,
do_rotary: bool = False,
rotary_interleaved: bool = False,
device="cuda",
local_window_size: int = -1,
):
super().__init__(
"GroupQueryAttention",
batch_size,
sequence_length,
max_sequence_length,
past_sequence_length,
num_heads,
kv_num_heads,
head_size,
softmax_scale,
do_rotary,
rotary_interleaved,
device,
)
self.local_window_size = local_window_size
def __init__(self, b, s, s2, sp, n, n2, h):
self.batch_size = b
self.sequence_length = s
self.kv_sequence_length = s2
self.past_sequence_length = sp
self.num_heads = n
self.kv_num_heads = n2
self.head_size = h
def shape_dict(self):
shapes = super().shape_dict()
shapes.update(
{
"seqlens_k": (self.batch_size,),
}
)
return shapes
def random_inputs(self):
feeds = super().random_inputs()
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
feeds.update(
{
"seqlens_k": k_seqlens - 1,
}
)
return feeds
def create_group_query_attention_graph_past(
config, causal=False, past_kv_format=InputFormats.QKV_BSNH, share_buffer=True
):
past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length
present_kv_seqlen = (
config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length
)
def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
assert config.dtype == torch.float16
float_type = TensorProto.FLOAT16
nodes = [
helper.make_node(
"GroupQueryAttention",
[
"query",
"key",
"value",
"key" if not config.is_packed_qkv else "",
"value" if not config.is_packed_qkv else "",
"past_key",
"past_value",
"past_sequence_length" if share_buffer else "",
"seqlens_k",
"total_sequence_length" if config.share_buffer else "",
"cos_cache" if config.do_rotary else "",
"sin_cache" if config.do_rotary else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
unidirectional=1 if causal else 0,
is_past_bsnh=1 if past_kv_format == InputFormats.QKV_BSNH else 0,
scale=config.softmax_scale,
local_window_size=config.local_window_size,
do_rotary=1 if config.do_rotary else 0,
rotary_interleaved=config.rotary_interleaved,
domain="com.microsoft",
),
]
shape_dict = config.shape_dict()
graph_input = [
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
helper.make_tensor_value_info(
"query",
TensorProto.FLOAT16,
[
config.batch_size,
config.sequence_length,
config.num_heads * config.head_size,
],
),
helper.make_tensor_value_info(
"key",
TensorProto.FLOAT16,
[
config.batch_size,
config.sequence_length,
config.kv_num_heads * config.head_size,
],
),
helper.make_tensor_value_info(
"value",
TensorProto.FLOAT16,
[
config.batch_size,
config.sequence_length,
config.kv_num_heads * config.head_size,
],
),
helper.make_tensor_value_info(
"past_key",
TensorProto.FLOAT16,
[
config.batch_size,
past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
config.head_size,
],
),
helper.make_tensor_value_info(
"past_value",
TensorProto.FLOAT16,
[
config.batch_size,
past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
config.head_size,
],
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
),
]
if share_buffer:
if config.do_rotary:
graph_input += [
helper.make_tensor_value_info(
"past_sequence_length",
TensorProto.INT32,
[1],
)
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
]
graph_output = [
helper.make_tensor_value_info(
"output",
TensorProto.FLOAT16,
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
),
helper.make_tensor_value_info(
"present_key",
TensorProto.FLOAT16,
[
config.batch_size,
present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
config.head_size,
],
),
helper.make_tensor_value_info(
"present_value",
TensorProto.FLOAT16,
[
config.batch_size,
present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
config.head_size,
],
),
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
]
graph = helper.make_graph(
@ -168,172 +215,202 @@ def create_group_query_attention_graph_past(
return model.SerializeToString()
def create_gqa_session(
config: Config,
causal: bool = False,
past_format=InputFormats.QKV_BSNH,
share_buffer: bool = True,
) -> InferenceSession:
onnx_model_str = create_group_query_attention_graph_past(config, causal, past_format, share_buffer)
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession:
session_options = SessionOptions()
ort_session = InferenceSession(
onnx_model_str,
session_options,
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
)
return ort_session
def bind_io(io_binding, input_dict, device, share_buffer=True):
io_binding.bind_cpu_input("query", input_dict["query"])
io_binding.bind_cpu_input("key", input_dict["key"])
io_binding.bind_cpu_input("value", input_dict["value"])
io_binding.bind_input(
"past_key", "cuda", 0, "float16", input_dict["past_key"].shape(), input_dict["past_key"].data_ptr()
)
io_binding.bind_input(
"past_value",
"cuda",
0,
"float16",
input_dict["past_value"].shape(),
input_dict["past_value"].data_ptr(),
)
io_binding.bind_output("output")
if share_buffer:
io_binding.bind_cpu_input("past_sequence_length", input_dict["past_sequence_length"])
io_binding.bind_output(
"present_key",
device_type="cuda",
device_id=device,
element_type="float16",
shape=input_dict["past_key"].shape(),
buffer_ptr=input_dict["past_key"].data_ptr(),
class OrtGroupQueryAttention:
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
def __init__(self, config: GroupQueryAttentionConfig):
device = config.device
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
io_binding.bind_output(
"present_value",
device_type="cuda",
device_id=device,
element_type="float16",
shape=input_dict["past_value"].shape(),
buffer_ptr=input_dict["past_value"].data_ptr(),
onnx_model_str = create_group_query_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.feed_dict = config.random_inputs()
def infer(self):
return self.gpu_binding.infer(self.feed_dict)
def get_plot_algos(sm: int):
# GQA with local windows only works in sm=8x
if sm >= 80:
return {
"line_vals": ["ort_gqa", "ort_gqa_local"],
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"],
"styles": [("red", "-"), ("blue", "-")],
}
else:
io_binding.bind_output("present_key")
io_binding.bind_output("present_value")
return {
"line_vals": ["ort_gqa"],
"line_names": ["ORT-GQA-Dense"],
"styles": [("green", "-")],
}
def measure_latency(ort_session, io_binding):
start = time.time()
_ = ort_session.run_with_iobinding(io_binding)
end = time.time()
return end - start
def plot_prompt_performance(
sm: int,
batch_size=4,
num_heads=32,
kv_num_heads=8,
max_seq_len=8192,
head_size=128,
):
import triton
algos = get_plot_algos(sm)
configs = [
triton.testing.Benchmark(
x_names=["sequence_length"],
x_vals=[2**i for i in range(4, 14)],
line_arg="provider",
ylabel="ms",
**algos,
plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
args={
"num_heads": num_heads,
"kv_num_heads": kv_num_heads,
"batch_size": batch_size,
"head_size": head_size,
},
)
]
@triton.testing.perf_report(configs)
def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"):
warmup = 15
repeat = 100
config: GroupQueryAttentionConfig = GroupQueryAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
max_sequence_length=max_seq_len,
past_sequence_length=0,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
local_window_size=1024 if provider == "ort_gqa_local" else -1,
device=device,
)
obj = OrtGroupQueryAttention(config)
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
return ms
benchmark.run(save_path=".", print_data=True)
def flops(batch, q_seqlen, kv_seqlen, head_size, num_heads):
return 4 * batch * q_seqlen * kv_seqlen * num_heads * head_size
def plot_token_performance(
sm: int,
batch_size=4,
num_heads=32,
kv_num_heads=8,
max_seq_len=8192,
head_size=128,
):
import triton
algos = get_plot_algos(sm)
configs = [
triton.testing.Benchmark(
x_names=["past_sequence_length"],
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
line_arg="provider",
ylabel="ms",
**algos,
plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
args={
"num_heads": num_heads,
"kv_num_heads": kv_num_heads,
"batch_size": batch_size,
"head_size": head_size,
},
)
]
@triton.testing.perf_report(configs)
def benchmark(
batch_size,
num_heads,
kv_num_heads,
past_sequence_length,
head_size,
provider,
device="cuda",
):
warmup = 15
repeat = 100
config: GroupQueryAttentionConfig = GroupQueryAttentionConfig(
batch_size=batch_size,
sequence_length=1,
max_sequence_length=max_seq_len,
past_sequence_length=past_sequence_length,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
local_window_size=1024 if provider == "ort_gqa_local" else -1,
device=device,
)
obj = OrtGroupQueryAttention(config)
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
return ms
benchmark.run(save_path=".", print_data=True)
def tflops_per_second(flop, time):
return (flop / time / 10**12) if not math.isnan(time) else 0.0
def run_performance_test(sm: int):
"""
Run performance tests for prompt and token generation.
def benchmark_op(session, io_binding, repeats=100):
# warm up session
_ = measure_latency(session, io_binding)
latency_list = []
for _ in range(repeats):
latency = measure_latency(session, io_binding)
latency_list.append(latency)
return statistics.mean(latency_list)
def run_tflops_test(dtype=torch.float16, repeats: int = 100):
device_id = torch.cuda.current_device()
device = torch.device("cuda", device_id)
print("---- GQA BSNH vs GQA BNSH ----")
print("op\tbatch\ts_kv\theads\th_dim\tms\tTFLOPS")
mean_bsnh_lat = 0
mean_bnsh_lat = 0
num_trials = 0
share_buffer = True
random.seed(69)
for b in [1, 3, 8, 16]:
for s_q, s_kv in [(1, 128), (128, 256), (512, 512), (128, 1024), (1, 2048)]:
for n_q, n_kv in [(8, 8), (16, 8), (32, 32), (12, 3), (128, 64)]:
for h in [32, 64, 128]:
sp = random.randint(1, s_kv - 1) if s_kv - 1 > 0 else 0
config = Config(b, s_q, s_kv, sp, n_q, n_kv, h)
bsnh_session = create_gqa_session(
config,
causal=False,
past_format=InputFormats.QKV_BSNH,
share_buffer=share_buffer,
)
bnsh_session = create_gqa_session(
config,
causal=False,
past_format=InputFormats.QKV_BNSH,
share_buffer=share_buffer,
)
q = torch.randn(b, s_q, n_q * h, device=device, dtype=dtype)
kv = torch.randn(b, s_q, 2, n_kv * h, device=device, dtype=dtype)
k, v = kv.unbind(dim=2)
past_kv = torch.rand(b, s_kv if share_buffer else sp, 2, n_kv, h, device=device, dtype=dtype)
past_k, past_v = past_kv.unbind(dim=2)
input_dict_bsnh = {
"query": q.detach().cpu().numpy(),
"key": k.detach().cpu().numpy(),
"value": v.detach().cpu().numpy(),
"past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", device_id),
"past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", device_id),
}
input_dict_bnsh = {
"query": q.detach().cpu().numpy(),
"key": k.detach().cpu().numpy(),
"value": v.detach().cpu().numpy(),
"past_key": OrtValue.ortvalue_from_numpy(
past_k.transpose(1, 2).detach().cpu().numpy(), "cuda", 0
),
"past_value": OrtValue.ortvalue_from_numpy(
past_v.transpose(1, 2).detach().cpu().numpy(), "cuda", 0
),
}
if share_buffer:
input_dict_bsnh["past_sequence_length"] = (
torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
)
input_dict_bnsh["past_sequence_length"] = (
torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
)
io_binding_bsnh = bsnh_session.io_binding()
io_binding_bnsh = bnsh_session.io_binding()
bind_io(io_binding_bsnh, input_dict_bsnh, device_id, share_buffer)
bind_io(io_binding_bnsh, input_dict_bnsh, device_id, share_buffer)
average_gqa_bsnh_latency = benchmark_op(bsnh_session, io_binding_bsnh, repeats)
average_gqa_bnsh_latency = benchmark_op(bnsh_session, io_binding_bnsh, repeats)
del bsnh_session
del bnsh_session
# compute TFLOPS per second
bsnh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bsnh_latency)
print(f"bsnh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bsnh_latency * 1000:.2f}\t{bsnh_speed:.2f}")
bnsh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bnsh_latency)
print(f"bnsh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bnsh_latency * 1000:.2f}\t{bnsh_speed:.2f}")
print("---------")
if average_gqa_bsnh_latency > 10 * average_gqa_bnsh_latency:
continue
num_trials += 1
mean_bsnh_lat += average_gqa_bsnh_latency
mean_bnsh_lat += average_gqa_bnsh_latency
mean_bsnh_lat /= num_trials
mean_bnsh_lat /= num_trials
print(f"average bsnh latency:\t{mean_bsnh_lat}")
print(f"average bnsh latency:\t{mean_bnsh_lat}")
"""
for batch_size in [1, 4, 8, 16]:
for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]:
for head_size in [64, 128]:
plot_prompt_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
max_seq_len=8192,
head_size=head_size,
)
plot_token_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
max_seq_len=8192,
head_size=head_size,
)
if __name__ == "__main__":
run_tflops_test()
major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor
s = torch.cuda.Stream()
with torch.cuda.stream(s), torch.no_grad():
run_performance_test(sm)