mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Fix GroupQueryAttention benchmark script (#20291)
### Description Fix a few issues in GQA: (1) memory efficient attention does not have bfloat16, need disable it when bfloat16 is used. (2) When prompt length is 1, it is not classified as prompt. (3) Fix benchmark_gqa.py (4) Add a comment about seqlen_k to avoid confusion. ### Motivation and Context https://github.com/microsoft/onnxruntime/pull/20279
This commit is contained in:
parent
b6d9abf150
commit
1f509215bc
5 changed files with 362 additions and 282 deletions
|
|
@ -45,7 +45,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
|
||||
is_past_bsnh_ = false;
|
||||
is_unidirectional_ = true;
|
||||
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
|
||||
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
|
||||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
|
|
@ -59,7 +60,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
#endif
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
|
||||
// Memory efficient attention only supports float and float16, not bfloat16.
|
||||
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
|
|
@ -160,7 +162,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
!disable_memory_efficient_attention_ &&
|
||||
local_window_size_ == -1 &&
|
||||
(parameters.head_size & 7) == 0 &&
|
||||
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
|
||||
|
|
|
|||
|
|
@ -32,8 +32,6 @@ Status CheckInputs(const Tensor* query,
|
|||
// query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
|
||||
// key (K) : (B, S, D_kv) or nullptr
|
||||
// value (V) : (B, S, D_kv) or nullptr
|
||||
ORT_UNUSED_PARAMETER(value);
|
||||
|
||||
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
|
||||
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
|
||||
const bool is_packed_qkv = key == nullptr;
|
||||
|
|
@ -241,7 +239,11 @@ Status CheckInputs(const Tensor* query,
|
|||
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
|
||||
}
|
||||
|
||||
bool is_prompt = sequence_length != 1;
|
||||
bool is_prompt = (sequence_length == total_sequence_length);
|
||||
if (!is_prompt && sequence_length != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sequence_length shall be 1 when it is not prompt.");
|
||||
}
|
||||
|
||||
if (parameters != nullptr) {
|
||||
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
|
||||
|
|
@ -256,7 +258,6 @@ Status CheckInputs(const Tensor* query,
|
|||
output_parameters->kv_num_heads = kv_num_heads;
|
||||
output_parameters->rotary_dim = rotary_dim;
|
||||
output_parameters->is_packed_qkv = is_packed_qkv;
|
||||
output_parameters->is_unidirectional = true;
|
||||
output_parameters->is_prompt = is_prompt;
|
||||
output_parameters->scale = scale;
|
||||
output_parameters->qkv_format = qkv_format;
|
||||
|
|
|
|||
|
|
@ -632,7 +632,7 @@ Status FlashAttention(
|
|||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
bool is_causal = true;
|
||||
bool is_causal = parameters.is_unidirectional;
|
||||
bool is_bf16 = std::is_same<T, BFloat16>::value;
|
||||
|
||||
void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
|
|
|
|||
|
|
@ -1104,6 +1104,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
OpSchema::Optional)
|
||||
.Input(5,
|
||||
"seqlens_k",
|
||||
// For prompt, the value is number of tokens (excluding padding) - 1.
|
||||
"1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.",
|
||||
"M")
|
||||
.Input(6,
|
||||
|
|
|
|||
|
|
@ -1,160 +1,207 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux:
|
||||
sh benchmark_mha.sh
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
import statistics
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
from onnxruntime import InferenceSession, SessionOptions
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
|
||||
|
||||
|
||||
class InputFormats:
|
||||
QKV_BSNH = 0
|
||||
QKV_BNSH = 1
|
||||
class AttentionConfig:
|
||||
def __init__(
|
||||
self,
|
||||
operator: str,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
max_sequence_length: int,
|
||||
past_sequence_length: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
softmax_scale: Optional[float],
|
||||
do_rotary: bool,
|
||||
rotary_interleaved: bool,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
share_buffer: bool = True,
|
||||
is_packed_qkv: bool = False,
|
||||
):
|
||||
self.operator = operator
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
|
||||
|
||||
# Derived values
|
||||
self.total_sequence_length = sequence_length + past_sequence_length
|
||||
self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
|
||||
self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
|
||||
|
||||
self.do_rotary = do_rotary
|
||||
self.rotary_interleaved = rotary_interleaved
|
||||
self.device = device
|
||||
|
||||
self.share_buffer = share_buffer
|
||||
self.is_packed_qkv = is_packed_qkv
|
||||
self.dtype = dtype
|
||||
|
||||
def shape_dict(self):
|
||||
return {
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"total_sequence_length": (1,),
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
}
|
||||
|
||||
def get_cos_sin_cache(self, dtype):
|
||||
rotary_fraction = 1.0
|
||||
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
|
||||
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=dtype)
|
||||
sin = torch.sin(angle).to(dtype=dtype)
|
||||
return cos.to(device=self.device), sin.to(device=self.device)
|
||||
|
||||
def random_inputs(self):
|
||||
device = self.device
|
||||
# bfloat16 is not supported in ORT python I/O binding API
|
||||
dtype = torch.float16
|
||||
shape_dict = self.shape_dict()
|
||||
|
||||
torch.manual_seed(123)
|
||||
feeds = {
|
||||
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
|
||||
}
|
||||
|
||||
if self.do_rotary:
|
||||
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
|
||||
feeds["cos_cache"] = cos_cache
|
||||
feeds["sin_cache"] = sin_cache
|
||||
|
||||
return feeds
|
||||
|
||||
|
||||
class Config:
|
||||
batch_size = 0
|
||||
sequence_length = 0
|
||||
kv_sequence_length = 0
|
||||
past_sequence_length = 0
|
||||
num_heads = 0
|
||||
kv_num_heads = 0
|
||||
head_size = 0
|
||||
class GroupQueryAttentionConfig(AttentionConfig):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
max_sequence_length: int,
|
||||
past_sequence_length: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
softmax_scale=None,
|
||||
do_rotary: bool = False,
|
||||
rotary_interleaved: bool = False,
|
||||
device="cuda",
|
||||
local_window_size: int = -1,
|
||||
):
|
||||
super().__init__(
|
||||
"GroupQueryAttention",
|
||||
batch_size,
|
||||
sequence_length,
|
||||
max_sequence_length,
|
||||
past_sequence_length,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
head_size,
|
||||
softmax_scale,
|
||||
do_rotary,
|
||||
rotary_interleaved,
|
||||
device,
|
||||
)
|
||||
self.local_window_size = local_window_size
|
||||
|
||||
def __init__(self, b, s, s2, sp, n, n2, h):
|
||||
self.batch_size = b
|
||||
self.sequence_length = s
|
||||
self.kv_sequence_length = s2
|
||||
self.past_sequence_length = sp
|
||||
self.num_heads = n
|
||||
self.kv_num_heads = n2
|
||||
self.head_size = h
|
||||
def shape_dict(self):
|
||||
shapes = super().shape_dict()
|
||||
shapes.update(
|
||||
{
|
||||
"seqlens_k": (self.batch_size,),
|
||||
}
|
||||
)
|
||||
return shapes
|
||||
|
||||
def random_inputs(self):
|
||||
feeds = super().random_inputs()
|
||||
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
|
||||
feeds.update(
|
||||
{
|
||||
"seqlens_k": k_seqlens - 1,
|
||||
}
|
||||
)
|
||||
return feeds
|
||||
|
||||
|
||||
def create_group_query_attention_graph_past(
|
||||
config, causal=False, past_kv_format=InputFormats.QKV_BSNH, share_buffer=True
|
||||
):
|
||||
past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length
|
||||
present_kv_seqlen = (
|
||||
config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length
|
||||
)
|
||||
def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
|
||||
assert config.dtype == torch.float16
|
||||
|
||||
float_type = TensorProto.FLOAT16
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
[
|
||||
"query",
|
||||
"key",
|
||||
"value",
|
||||
"key" if not config.is_packed_qkv else "",
|
||||
"value" if not config.is_packed_qkv else "",
|
||||
"past_key",
|
||||
"past_value",
|
||||
"past_sequence_length" if share_buffer else "",
|
||||
"seqlens_k",
|
||||
"total_sequence_length" if config.share_buffer else "",
|
||||
"cos_cache" if config.do_rotary else "",
|
||||
"sin_cache" if config.do_rotary else "",
|
||||
],
|
||||
["output", "present_key", "present_value"],
|
||||
"GroupQueryAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
kv_num_heads=config.kv_num_heads,
|
||||
unidirectional=1 if causal else 0,
|
||||
is_past_bsnh=1 if past_kv_format == InputFormats.QKV_BSNH else 0,
|
||||
scale=config.softmax_scale,
|
||||
local_window_size=config.local_window_size,
|
||||
do_rotary=1 if config.do_rotary else 0,
|
||||
rotary_interleaved=config.rotary_interleaved,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
graph_input = [
|
||||
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
|
||||
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
|
||||
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
|
||||
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
|
||||
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
|
||||
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
|
||||
helper.make_tensor_value_info(
|
||||
"query",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.sequence_length,
|
||||
config.num_heads * config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.sequence_length,
|
||||
config.kv_num_heads * config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"value",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.sequence_length,
|
||||
config.kv_num_heads * config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"past_key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
|
||||
config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"past_value",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
|
||||
config.head_size,
|
||||
],
|
||||
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
|
||||
),
|
||||
]
|
||||
if share_buffer:
|
||||
|
||||
if config.do_rotary:
|
||||
graph_input += [
|
||||
helper.make_tensor_value_info(
|
||||
"past_sequence_length",
|
||||
TensorProto.INT32,
|
||||
[1],
|
||||
)
|
||||
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
|
||||
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
|
||||
]
|
||||
|
||||
graph_output = [
|
||||
helper.make_tensor_value_info(
|
||||
"output",
|
||||
TensorProto.FLOAT16,
|
||||
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"present_key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
|
||||
config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"present_value",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
|
||||
config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
|
||||
config.head_size,
|
||||
],
|
||||
),
|
||||
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
|
||||
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
|
||||
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -168,172 +215,202 @@ def create_group_query_attention_graph_past(
|
|||
return model.SerializeToString()
|
||||
|
||||
|
||||
def create_gqa_session(
|
||||
config: Config,
|
||||
causal: bool = False,
|
||||
past_format=InputFormats.QKV_BSNH,
|
||||
share_buffer: bool = True,
|
||||
) -> InferenceSession:
|
||||
onnx_model_str = create_group_query_attention_graph_past(config, causal, past_format, share_buffer)
|
||||
sess_options = SessionOptions()
|
||||
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
|
||||
def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession:
|
||||
session_options = SessionOptions()
|
||||
ort_session = InferenceSession(
|
||||
onnx_model_str,
|
||||
session_options,
|
||||
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
|
||||
)
|
||||
return ort_session
|
||||
|
||||
|
||||
def bind_io(io_binding, input_dict, device, share_buffer=True):
|
||||
io_binding.bind_cpu_input("query", input_dict["query"])
|
||||
io_binding.bind_cpu_input("key", input_dict["key"])
|
||||
io_binding.bind_cpu_input("value", input_dict["value"])
|
||||
io_binding.bind_input(
|
||||
"past_key", "cuda", 0, "float16", input_dict["past_key"].shape(), input_dict["past_key"].data_ptr()
|
||||
)
|
||||
io_binding.bind_input(
|
||||
"past_value",
|
||||
"cuda",
|
||||
0,
|
||||
"float16",
|
||||
input_dict["past_value"].shape(),
|
||||
input_dict["past_value"].data_ptr(),
|
||||
)
|
||||
io_binding.bind_output("output")
|
||||
if share_buffer:
|
||||
io_binding.bind_cpu_input("past_sequence_length", input_dict["past_sequence_length"])
|
||||
io_binding.bind_output(
|
||||
"present_key",
|
||||
device_type="cuda",
|
||||
device_id=device,
|
||||
element_type="float16",
|
||||
shape=input_dict["past_key"].shape(),
|
||||
buffer_ptr=input_dict["past_key"].data_ptr(),
|
||||
class OrtGroupQueryAttention:
|
||||
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
|
||||
|
||||
def __init__(self, config: GroupQueryAttentionConfig):
|
||||
device = config.device
|
||||
cuda_provider_options = CudaSession.get_cuda_provider_options(
|
||||
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
|
||||
)
|
||||
io_binding.bind_output(
|
||||
"present_value",
|
||||
device_type="cuda",
|
||||
device_id=device,
|
||||
element_type="float16",
|
||||
shape=input_dict["past_value"].shape(),
|
||||
buffer_ptr=input_dict["past_value"].data_ptr(),
|
||||
onnx_model_str = create_group_query_attention_onnx_model(config)
|
||||
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
|
||||
self.gpu_binding_manager = GpuBindingManager(
|
||||
ort_session=self.ort_session,
|
||||
device=device,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
max_cuda_graphs=2,
|
||||
)
|
||||
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
|
||||
self.gpu_binding = self.gpu_binding_manager.get_binding(
|
||||
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
|
||||
)
|
||||
self.feed_dict = config.random_inputs()
|
||||
|
||||
def infer(self):
|
||||
return self.gpu_binding.infer(self.feed_dict)
|
||||
|
||||
|
||||
def get_plot_algos(sm: int):
|
||||
# GQA with local windows only works in sm=8x
|
||||
if sm >= 80:
|
||||
return {
|
||||
"line_vals": ["ort_gqa", "ort_gqa_local"],
|
||||
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"],
|
||||
"styles": [("red", "-"), ("blue", "-")],
|
||||
}
|
||||
else:
|
||||
io_binding.bind_output("present_key")
|
||||
io_binding.bind_output("present_value")
|
||||
return {
|
||||
"line_vals": ["ort_gqa"],
|
||||
"line_names": ["ORT-GQA-Dense"],
|
||||
"styles": [("green", "-")],
|
||||
}
|
||||
|
||||
|
||||
def measure_latency(ort_session, io_binding):
|
||||
start = time.time()
|
||||
_ = ort_session.run_with_iobinding(io_binding)
|
||||
end = time.time()
|
||||
return end - start
|
||||
def plot_prompt_performance(
|
||||
sm: int,
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
kv_num_heads=8,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
):
|
||||
import triton
|
||||
|
||||
algos = get_plot_algos(sm)
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 14)],
|
||||
line_arg="provider",
|
||||
ylabel="ms",
|
||||
**algos,
|
||||
plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
|
||||
args={
|
||||
"num_heads": num_heads,
|
||||
"kv_num_heads": kv_num_heads,
|
||||
"batch_size": batch_size,
|
||||
"head_size": head_size,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
config: GroupQueryAttentionConfig = GroupQueryAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
max_sequence_length=max_seq_len,
|
||||
past_sequence_length=0,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
local_window_size=1024 if provider == "ort_gqa_local" else -1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
obj = OrtGroupQueryAttention(config)
|
||||
|
||||
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
|
||||
return ms
|
||||
|
||||
benchmark.run(save_path=".", print_data=True)
|
||||
|
||||
|
||||
def flops(batch, q_seqlen, kv_seqlen, head_size, num_heads):
|
||||
return 4 * batch * q_seqlen * kv_seqlen * num_heads * head_size
|
||||
def plot_token_performance(
|
||||
sm: int,
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
kv_num_heads=8,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
):
|
||||
import triton
|
||||
|
||||
algos = get_plot_algos(sm)
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["past_sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
|
||||
line_arg="provider",
|
||||
ylabel="ms",
|
||||
**algos,
|
||||
plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
|
||||
args={
|
||||
"num_heads": num_heads,
|
||||
"kv_num_heads": kv_num_heads,
|
||||
"batch_size": batch_size,
|
||||
"head_size": head_size,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(
|
||||
batch_size,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
past_sequence_length,
|
||||
head_size,
|
||||
provider,
|
||||
device="cuda",
|
||||
):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
config: GroupQueryAttentionConfig = GroupQueryAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=1,
|
||||
max_sequence_length=max_seq_len,
|
||||
past_sequence_length=past_sequence_length,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
local_window_size=1024 if provider == "ort_gqa_local" else -1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
obj = OrtGroupQueryAttention(config)
|
||||
|
||||
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
|
||||
return ms
|
||||
|
||||
benchmark.run(save_path=".", print_data=True)
|
||||
|
||||
|
||||
def tflops_per_second(flop, time):
|
||||
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
||||
def run_performance_test(sm: int):
|
||||
"""
|
||||
Run performance tests for prompt and token generation.
|
||||
|
||||
|
||||
def benchmark_op(session, io_binding, repeats=100):
|
||||
# warm up session
|
||||
_ = measure_latency(session, io_binding)
|
||||
|
||||
latency_list = []
|
||||
for _ in range(repeats):
|
||||
latency = measure_latency(session, io_binding)
|
||||
latency_list.append(latency)
|
||||
return statistics.mean(latency_list)
|
||||
|
||||
|
||||
def run_tflops_test(dtype=torch.float16, repeats: int = 100):
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
print("---- GQA BSNH vs GQA BNSH ----")
|
||||
print("op\tbatch\ts_kv\theads\th_dim\tms\tTFLOPS")
|
||||
mean_bsnh_lat = 0
|
||||
mean_bnsh_lat = 0
|
||||
num_trials = 0
|
||||
share_buffer = True
|
||||
random.seed(69)
|
||||
for b in [1, 3, 8, 16]:
|
||||
for s_q, s_kv in [(1, 128), (128, 256), (512, 512), (128, 1024), (1, 2048)]:
|
||||
for n_q, n_kv in [(8, 8), (16, 8), (32, 32), (12, 3), (128, 64)]:
|
||||
for h in [32, 64, 128]:
|
||||
sp = random.randint(1, s_kv - 1) if s_kv - 1 > 0 else 0
|
||||
config = Config(b, s_q, s_kv, sp, n_q, n_kv, h)
|
||||
|
||||
bsnh_session = create_gqa_session(
|
||||
config,
|
||||
causal=False,
|
||||
past_format=InputFormats.QKV_BSNH,
|
||||
share_buffer=share_buffer,
|
||||
)
|
||||
bnsh_session = create_gqa_session(
|
||||
config,
|
||||
causal=False,
|
||||
past_format=InputFormats.QKV_BNSH,
|
||||
share_buffer=share_buffer,
|
||||
)
|
||||
|
||||
q = torch.randn(b, s_q, n_q * h, device=device, dtype=dtype)
|
||||
kv = torch.randn(b, s_q, 2, n_kv * h, device=device, dtype=dtype)
|
||||
k, v = kv.unbind(dim=2)
|
||||
|
||||
past_kv = torch.rand(b, s_kv if share_buffer else sp, 2, n_kv, h, device=device, dtype=dtype)
|
||||
past_k, past_v = past_kv.unbind(dim=2)
|
||||
|
||||
input_dict_bsnh = {
|
||||
"query": q.detach().cpu().numpy(),
|
||||
"key": k.detach().cpu().numpy(),
|
||||
"value": v.detach().cpu().numpy(),
|
||||
"past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", device_id),
|
||||
"past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", device_id),
|
||||
}
|
||||
input_dict_bnsh = {
|
||||
"query": q.detach().cpu().numpy(),
|
||||
"key": k.detach().cpu().numpy(),
|
||||
"value": v.detach().cpu().numpy(),
|
||||
"past_key": OrtValue.ortvalue_from_numpy(
|
||||
past_k.transpose(1, 2).detach().cpu().numpy(), "cuda", 0
|
||||
),
|
||||
"past_value": OrtValue.ortvalue_from_numpy(
|
||||
past_v.transpose(1, 2).detach().cpu().numpy(), "cuda", 0
|
||||
),
|
||||
}
|
||||
if share_buffer:
|
||||
input_dict_bsnh["past_sequence_length"] = (
|
||||
torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
|
||||
)
|
||||
input_dict_bnsh["past_sequence_length"] = (
|
||||
torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
|
||||
)
|
||||
|
||||
io_binding_bsnh = bsnh_session.io_binding()
|
||||
io_binding_bnsh = bnsh_session.io_binding()
|
||||
bind_io(io_binding_bsnh, input_dict_bsnh, device_id, share_buffer)
|
||||
bind_io(io_binding_bnsh, input_dict_bnsh, device_id, share_buffer)
|
||||
average_gqa_bsnh_latency = benchmark_op(bsnh_session, io_binding_bsnh, repeats)
|
||||
average_gqa_bnsh_latency = benchmark_op(bnsh_session, io_binding_bnsh, repeats)
|
||||
|
||||
del bsnh_session
|
||||
del bnsh_session
|
||||
|
||||
# compute TFLOPS per second
|
||||
bsnh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bsnh_latency)
|
||||
print(f"bsnh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bsnh_latency * 1000:.2f}\t{bsnh_speed:.2f}")
|
||||
bnsh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bnsh_latency)
|
||||
print(f"bnsh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bnsh_latency * 1000:.2f}\t{bnsh_speed:.2f}")
|
||||
print("---------")
|
||||
if average_gqa_bsnh_latency > 10 * average_gqa_bnsh_latency:
|
||||
continue
|
||||
num_trials += 1
|
||||
mean_bsnh_lat += average_gqa_bsnh_latency
|
||||
mean_bnsh_lat += average_gqa_bnsh_latency
|
||||
mean_bsnh_lat /= num_trials
|
||||
mean_bnsh_lat /= num_trials
|
||||
print(f"average bsnh latency:\t{mean_bsnh_lat}")
|
||||
print(f"average bnsh latency:\t{mean_bnsh_lat}")
|
||||
"""
|
||||
for batch_size in [1, 4, 8, 16]:
|
||||
for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]:
|
||||
for head_size in [64, 128]:
|
||||
plot_prompt_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
max_seq_len=8192,
|
||||
head_size=head_size,
|
||||
)
|
||||
plot_token_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
max_seq_len=8192,
|
||||
head_size=head_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tflops_test()
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
sm = major * 10 + minor
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s), torch.no_grad():
|
||||
run_performance_test(sm)
|
||||
|
|
|
|||
Loading…
Reference in a new issue