From 1f509215bcbbba37e664316b40e2f3affd606396 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 8 May 2024 09:48:46 -0700 Subject: [PATCH] Fix GroupQueryAttention benchmark script (#20291) ### Description Fix a few issues in GQA: (1) memory efficient attention does not have bfloat16, need disable it when bfloat16 is used. (2) When prompt length is 1, it is not classified as prompt. (3) Fix benchmark_gqa.py (4) Add a comment about seqlen_k to avoid confusion. ### Motivation and Context https://github.com/microsoft/onnxruntime/pull/20279 --- .../cuda/bert/group_query_attention.cc | 7 +- .../cuda/bert/group_query_attention_helper.h | 9 +- .../cuda/bert/group_query_attention_impl.cu | 2 +- .../core/graph/contrib_ops/bert_defs.cc | 1 + .../test/python/transformers/benchmark_gqa.py | 625 ++++++++++-------- 5 files changed, 362 insertions(+), 282 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 1cd0c69233..3c968d6c8b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -45,7 +45,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + is_past_bsnh_ = false; + is_unidirectional_ = true; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; @@ -59,7 +60,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #endif #if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = sizeof(T) != 2 || + // Memory efficient attention only supports float and float16, not bfloat16. + disable_memory_efficient_attention_ = std::is_same::value || ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else disable_memory_efficient_attention_ = true; @@ -160,7 +162,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !disable_memory_efficient_attention_ && local_window_size_ == -1 && (parameters.head_size & 7) == 0 && - parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 8352397f68..91418b17e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -32,8 +32,6 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) // key (K) : (B, S, D_kv) or nullptr // value (V) : (B, S, D_kv) or nullptr - ORT_UNUSED_PARAMETER(value); - AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; const bool is_packed_qkv = key == nullptr; @@ -241,7 +239,11 @@ Status CheckInputs(const Tensor* query, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); } - bool is_prompt = sequence_length != 1; + bool is_prompt = (sequence_length == total_sequence_length); + if (!is_prompt && sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sequence_length shall be 1 when it is not prompt."); + } if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); @@ -256,7 +258,6 @@ Status CheckInputs(const Tensor* query, output_parameters->kv_num_heads = kv_num_heads; output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; - output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; output_parameters->qkv_format = qkv_format; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 461d4e318a..2dd6e0acfd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -632,7 +632,7 @@ Status FlashAttention( const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - bool is_causal = true; + bool is_causal = parameters.is_unidirectional; bool is_bf16 = std::is_same::value; void* query = reinterpret_cast(const_cast(data.query)); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 6cac8f9a53..e9de04f8a9 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1104,6 +1104,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(5, "seqlens_k", + // For prompt, the value is number of tokens (excluding padding) - 1. "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", "M") .Input(6, diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index a9bef025a7..7fcd56bb8f 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -1,160 +1,207 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -""" -Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: -sh benchmark_mha.sh -""" - import math -import random -import statistics -import time +from typing import Optional import torch from onnx import TensorProto, helper -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnxruntime import InferenceSession, SessionOptions +from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager -class InputFormats: - QKV_BSNH = 0 - QKV_BNSH = 1 +class AttentionConfig: + def __init__( + self, + operator: str, + batch_size: int, + sequence_length: int, + max_sequence_length: int, + past_sequence_length: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + softmax_scale: Optional[float], + do_rotary: bool, + rotary_interleaved: bool, + device="cuda", + dtype=torch.float16, + share_buffer: bool = True, + is_packed_qkv: bool = False, + ): + self.operator = operator + self.batch_size = batch_size + self.sequence_length = sequence_length + self.max_sequence_length = max_sequence_length + self.past_sequence_length = past_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size + self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5) + + # Derived values + self.total_sequence_length = sequence_length + past_sequence_length + self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length + self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length) + + self.do_rotary = do_rotary + self.rotary_interleaved = rotary_interleaved + self.device = device + + self.share_buffer = share_buffer + self.is_packed_qkv = is_packed_qkv + self.dtype = dtype + + def shape_dict(self): + return { + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), + "past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), + "past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), + "total_sequence_length": (1,), + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), + "present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), + "cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + "sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + } + + def get_cos_sin_cache(self, dtype): + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16 + angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + return cos.to(device=self.device), sin.to(device=self.device) + + def random_inputs(self): + device = self.device + # bfloat16 is not supported in ORT python I/O binding API + dtype = torch.float16 + shape_dict = self.shape_dict() + + torch.manual_seed(123) + feeds = { + "query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32), + } + + if self.do_rotary: + cos_cache, sin_cache = self.get_cos_sin_cache(dtype) + feeds["cos_cache"] = cos_cache + feeds["sin_cache"] = sin_cache + + return feeds -class Config: - batch_size = 0 - sequence_length = 0 - kv_sequence_length = 0 - past_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 +class GroupQueryAttentionConfig(AttentionConfig): + def __init__( + self, + batch_size: int, + sequence_length: int, + max_sequence_length: int, + past_sequence_length: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + softmax_scale=None, + do_rotary: bool = False, + rotary_interleaved: bool = False, + device="cuda", + local_window_size: int = -1, + ): + super().__init__( + "GroupQueryAttention", + batch_size, + sequence_length, + max_sequence_length, + past_sequence_length, + num_heads, + kv_num_heads, + head_size, + softmax_scale, + do_rotary, + rotary_interleaved, + device, + ) + self.local_window_size = local_window_size - def __init__(self, b, s, s2, sp, n, n2, h): - self.batch_size = b - self.sequence_length = s - self.kv_sequence_length = s2 - self.past_sequence_length = sp - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + def shape_dict(self): + shapes = super().shape_dict() + shapes.update( + { + "seqlens_k": (self.batch_size,), + } + ) + return shapes + + def random_inputs(self): + feeds = super().random_inputs() + k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length + feeds.update( + { + "seqlens_k": k_seqlens - 1, + } + ) + return feeds -def create_group_query_attention_graph_past( - config, causal=False, past_kv_format=InputFormats.QKV_BSNH, share_buffer=True -): - past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length - present_kv_seqlen = ( - config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length - ) +def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): + assert config.dtype == torch.float16 + + float_type = TensorProto.FLOAT16 nodes = [ helper.make_node( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not config.is_packed_qkv else "", + "value" if not config.is_packed_qkv else "", "past_key", "past_value", - "past_sequence_length" if share_buffer else "", + "seqlens_k", + "total_sequence_length" if config.share_buffer else "", + "cos_cache" if config.do_rotary else "", + "sin_cache" if config.do_rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, - unidirectional=1 if causal else 0, - is_past_bsnh=1 if past_kv_format == InputFormats.QKV_BSNH else 0, + scale=config.softmax_scale, + local_window_size=config.local_window_size, + do_rotary=1 if config.do_rotary else 0, + rotary_interleaved=config.rotary_interleaved, domain="com.microsoft", ), ] + shape_dict = config.shape_dict() graph_input = [ + helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])), + helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])), + helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])), + helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])), + helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])), helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_key", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_value", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen, - config.head_size, - ], + "total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"]) ), ] - if share_buffer: + + if config.do_rotary: graph_input += [ - helper.make_tensor_value_info( - "past_sequence_length", - TensorProto.INT32, - [1], - ) + helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])), + helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), ] graph_output = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen, - config.head_size, - ], - ), + helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])), + helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])), + helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])), ] graph = helper.make_graph( @@ -168,172 +215,202 @@ def create_group_query_attention_graph_past( return model.SerializeToString() -def create_gqa_session( - config: Config, - causal: bool = False, - past_format=InputFormats.QKV_BSNH, - share_buffer: bool = True, -) -> InferenceSession: - onnx_model_str = create_group_query_attention_graph_past(config, causal, past_format, share_buffer) - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) +def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession: + session_options = SessionOptions() + ort_session = InferenceSession( + onnx_model_str, + session_options, + providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"], + ) return ort_session -def bind_io(io_binding, input_dict, device, share_buffer=True): - io_binding.bind_cpu_input("query", input_dict["query"]) - io_binding.bind_cpu_input("key", input_dict["key"]) - io_binding.bind_cpu_input("value", input_dict["value"]) - io_binding.bind_input( - "past_key", "cuda", 0, "float16", input_dict["past_key"].shape(), input_dict["past_key"].data_ptr() - ) - io_binding.bind_input( - "past_value", - "cuda", - 0, - "float16", - input_dict["past_value"].shape(), - input_dict["past_value"].data_ptr(), - ) - io_binding.bind_output("output") - if share_buffer: - io_binding.bind_cpu_input("past_sequence_length", input_dict["past_sequence_length"]) - io_binding.bind_output( - "present_key", - device_type="cuda", - device_id=device, - element_type="float16", - shape=input_dict["past_key"].shape(), - buffer_ptr=input_dict["past_key"].data_ptr(), +class OrtGroupQueryAttention: + """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" + + def __init__(self, config: GroupQueryAttentionConfig): + device = config.device + cuda_provider_options = CudaSession.get_cuda_provider_options( + torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream ) - io_binding.bind_output( - "present_value", - device_type="cuda", - device_id=device, - element_type="float16", - shape=input_dict["past_value"].shape(), - buffer_ptr=input_dict["past_value"].data_ptr(), + onnx_model_str = create_group_query_attention_onnx_model(config) + self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) + self.gpu_binding_manager = GpuBindingManager( + ort_session=self.ort_session, + device=device, + stream=torch.cuda.current_stream().cuda_stream, + max_cuda_graphs=2, ) + buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} + self.gpu_binding = self.gpu_binding_manager.get_binding( + config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing + ) + self.feed_dict = config.random_inputs() + + def infer(self): + return self.gpu_binding.infer(self.feed_dict) + + +def get_plot_algos(sm: int): + # GQA with local windows only works in sm=8x + if sm >= 80: + return { + "line_vals": ["ort_gqa", "ort_gqa_local"], + "line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"], + "styles": [("red", "-"), ("blue", "-")], + } else: - io_binding.bind_output("present_key") - io_binding.bind_output("present_value") + return { + "line_vals": ["ort_gqa"], + "line_names": ["ORT-GQA-Dense"], + "styles": [("green", "-")], + } -def measure_latency(ort_session, io_binding): - start = time.time() - _ = ort_session.run_with_iobinding(io_binding) - end = time.time() - return end - start +def plot_prompt_performance( + sm: int, + batch_size=4, + num_heads=32, + kv_num_heads=8, + max_seq_len=8192, + head_size=128, +): + import triton + + algos = get_plot_algos(sm) + configs = [ + triton.testing.Benchmark( + x_names=["sequence_length"], + x_vals=[2**i for i in range(4, 14)], + line_arg="provider", + ylabel="ms", + **algos, + plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16", + args={ + "num_heads": num_heads, + "kv_num_heads": kv_num_heads, + "batch_size": batch_size, + "head_size": head_size, + }, + ) + ] + + @triton.testing.perf_report(configs) + def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"): + warmup = 15 + repeat = 100 + + config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + max_sequence_length=max_seq_len, + past_sequence_length=0, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + local_window_size=1024 if provider == "ort_gqa_local" else -1, + device=device, + ) + + obj = OrtGroupQueryAttention(config) + + ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat) + return ms + + benchmark.run(save_path=".", print_data=True) -def flops(batch, q_seqlen, kv_seqlen, head_size, num_heads): - return 4 * batch * q_seqlen * kv_seqlen * num_heads * head_size +def plot_token_performance( + sm: int, + batch_size=4, + num_heads=32, + kv_num_heads=8, + max_seq_len=8192, + head_size=128, +): + import triton + + algos = get_plot_algos(sm) + configs = [ + triton.testing.Benchmark( + x_names=["past_sequence_length"], + x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1], + line_arg="provider", + ylabel="ms", + **algos, + plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16", + args={ + "num_heads": num_heads, + "kv_num_heads": kv_num_heads, + "batch_size": batch_size, + "head_size": head_size, + }, + ) + ] + + @triton.testing.perf_report(configs) + def benchmark( + batch_size, + num_heads, + kv_num_heads, + past_sequence_length, + head_size, + provider, + device="cuda", + ): + warmup = 15 + repeat = 100 + + config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( + batch_size=batch_size, + sequence_length=1, + max_sequence_length=max_seq_len, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + local_window_size=1024 if provider == "ort_gqa_local" else -1, + device=device, + ) + + obj = OrtGroupQueryAttention(config) + + ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat) + return ms + + benchmark.run(save_path=".", print_data=True) -def tflops_per_second(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 +def run_performance_test(sm: int): + """ + Run performance tests for prompt and token generation. - -def benchmark_op(session, io_binding, repeats=100): - # warm up session - _ = measure_latency(session, io_binding) - - latency_list = [] - for _ in range(repeats): - latency = measure_latency(session, io_binding) - latency_list.append(latency) - return statistics.mean(latency_list) - - -def run_tflops_test(dtype=torch.float16, repeats: int = 100): - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - print("---- GQA BSNH vs GQA BNSH ----") - print("op\tbatch\ts_kv\theads\th_dim\tms\tTFLOPS") - mean_bsnh_lat = 0 - mean_bnsh_lat = 0 - num_trials = 0 - share_buffer = True - random.seed(69) - for b in [1, 3, 8, 16]: - for s_q, s_kv in [(1, 128), (128, 256), (512, 512), (128, 1024), (1, 2048)]: - for n_q, n_kv in [(8, 8), (16, 8), (32, 32), (12, 3), (128, 64)]: - for h in [32, 64, 128]: - sp = random.randint(1, s_kv - 1) if s_kv - 1 > 0 else 0 - config = Config(b, s_q, s_kv, sp, n_q, n_kv, h) - - bsnh_session = create_gqa_session( - config, - causal=False, - past_format=InputFormats.QKV_BSNH, - share_buffer=share_buffer, - ) - bnsh_session = create_gqa_session( - config, - causal=False, - past_format=InputFormats.QKV_BNSH, - share_buffer=share_buffer, - ) - - q = torch.randn(b, s_q, n_q * h, device=device, dtype=dtype) - kv = torch.randn(b, s_q, 2, n_kv * h, device=device, dtype=dtype) - k, v = kv.unbind(dim=2) - - past_kv = torch.rand(b, s_kv if share_buffer else sp, 2, n_kv, h, device=device, dtype=dtype) - past_k, past_v = past_kv.unbind(dim=2) - - input_dict_bsnh = { - "query": q.detach().cpu().numpy(), - "key": k.detach().cpu().numpy(), - "value": v.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", device_id), - "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", device_id), - } - input_dict_bnsh = { - "query": q.detach().cpu().numpy(), - "key": k.detach().cpu().numpy(), - "value": v.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy( - past_k.transpose(1, 2).detach().cpu().numpy(), "cuda", 0 - ), - "past_value": OrtValue.ortvalue_from_numpy( - past_v.transpose(1, 2).detach().cpu().numpy(), "cuda", 0 - ), - } - if share_buffer: - input_dict_bsnh["past_sequence_length"] = ( - torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy() - ) - input_dict_bnsh["past_sequence_length"] = ( - torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy() - ) - - io_binding_bsnh = bsnh_session.io_binding() - io_binding_bnsh = bnsh_session.io_binding() - bind_io(io_binding_bsnh, input_dict_bsnh, device_id, share_buffer) - bind_io(io_binding_bnsh, input_dict_bnsh, device_id, share_buffer) - average_gqa_bsnh_latency = benchmark_op(bsnh_session, io_binding_bsnh, repeats) - average_gqa_bnsh_latency = benchmark_op(bnsh_session, io_binding_bnsh, repeats) - - del bsnh_session - del bnsh_session - - # compute TFLOPS per second - bsnh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bsnh_latency) - print(f"bsnh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bsnh_latency * 1000:.2f}\t{bsnh_speed:.2f}") - bnsh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bnsh_latency) - print(f"bnsh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bnsh_latency * 1000:.2f}\t{bnsh_speed:.2f}") - print("---------") - if average_gqa_bsnh_latency > 10 * average_gqa_bnsh_latency: - continue - num_trials += 1 - mean_bsnh_lat += average_gqa_bsnh_latency - mean_bnsh_lat += average_gqa_bnsh_latency - mean_bsnh_lat /= num_trials - mean_bnsh_lat /= num_trials - print(f"average bsnh latency:\t{mean_bsnh_lat}") - print(f"average bnsh latency:\t{mean_bnsh_lat}") + """ + for batch_size in [1, 4, 8, 16]: + for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]: + for head_size in [64, 128]: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + max_seq_len=8192, + head_size=head_size, + ) + plot_token_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + max_seq_len=8192, + head_size=head_size, + ) if __name__ == "__main__": - run_tflops_test() + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + + s = torch.cuda.Stream() + with torch.cuda.stream(s), torch.no_grad(): + run_performance_test(sm)