From a2b0a69dccc398d4af042ec0adb80d3eecdbcccf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 12 Jun 2024 13:04:25 -0700 Subject: [PATCH] Update MultiHeadAttention benchmark to test CPU (#20972) ### Description MultiHeadAttention benchmark script only supports cuda provider right now. This updates the script to support testing cpu operator and plotting gpu latency. ### Motivation and Context Benchmark for the coming cpu flash attention. --- .../test/python/transformers/benchmark_mha.py | 659 ++++++++++++------ 1 file changed, 449 insertions(+), 210 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 1e75268ea6..b87d476ce5 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -12,125 +12,227 @@ import math import os import statistics import time +from typing import List import torch from onnx import TensorProto, helper -from onnxruntime import InferenceSession +from onnxruntime import InferenceSession, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession class InputFormats: - Q_K_V_BSNH = 0 + Q_K_V_BSNH_BSNH_BSNH = 0 QKV_BSN3H = 1 Q_KV_BSNH_BSN2H = 2 + Q_K_V_BSNH_BNSH_BNSH = 3 # For cross attention @staticmethod def input_format_str(format: int) -> str: - return "QKV" if format == 1 else "Q,KV" if format == 2 else "Q,K,V" + names = InputFormats.get_name_list() + return names[format] + + @staticmethod + def convert(format_str: str) -> int: + names = InputFormats.get_name_list() + return names.index(format_str) + + @staticmethod + def get_name_list() -> List[str]: + return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] -class Config: - batch_size: int = 0 - sequence_length: int = 0 - kv_sequence_length: int = 0 - num_heads: int = 0 - head_size: int = 0 - causal: bool = False - input_format: int = InputFormats.Q_K_V_BSNH - - def __init__(self, b, s, s2, n, h, causal, 
input_format): - self.batch_size = b - self.sequence_length = s - self.kv_sequence_length = s2 - self.num_heads = n - self.head_size = h +class MultiHeadAttentionConfig: + def __init__( + self, + batch_size: int, + sequence_length: int, + num_heads: int, + head_size: int, + causal: bool, + past_sequence_length: int = 0, + kv_sequence_length=None, + max_cache_sequence_length=None, + softmax_scale: float = 0.0, + device="cuda", + dtype=torch.float16, + use_kv_cache: bool = False, + share_past_present_buffer: bool = False, + input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, + ): + self.operator = "MultiHeadAttention" + self.batch_size = batch_size + self.sequence_length = sequence_length + self.kv_sequence_length = kv_sequence_length or sequence_length + self.max_cache_sequence_length = max_cache_sequence_length + self.past_sequence_length = past_sequence_length + self.num_heads = num_heads + self.head_size = head_size self.causal = causal + self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + + self.use_kv_cache = use_kv_cache + if not use_kv_cache: + assert past_sequence_length == 0 + + if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + # cross attention does not have past state + assert not use_kv_cache + + # Derived values + self.total_sequence_length = self.kv_sequence_length + past_sequence_length + self.past_buffer_length = self.max_cache_sequence_length if share_past_present_buffer else past_sequence_length + self.present_buffer_length = ( + self.max_cache_sequence_length if share_past_present_buffer else self.total_sequence_length + ) + + self.device = device + self.share_past_present_buffer = share_past_present_buffer self.input_format = input_format + self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H + self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H + self.dtype = dtype + + def shape_dict(self, input_format=None): + input_format = input_format or self.input_format + if input_format == 
InputFormats.Q_K_V_BSNH_BNSH_BNSH: + # cross attention does not have past state + return { + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), + "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + + if self.use_kv_cache: + shapes = { + "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), + "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), + "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), + } + else: + shapes = { + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + + if input_format == InputFormats.QKV_BSN3H: + shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)}) + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes.update( + { + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), + } + ) + else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH + shapes.update( + { + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + ) + return shapes + + def random_inputs(self, seed: int = 123): + device = self.device + dtype = self.dtype + + shape_dict = self.shape_dict() + + if seed > 0: + torch.manual_seed(seed) + + shape = 
(self.batch_size, self.sequence_length, self.num_heads, self.head_size) + q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + k_bnsh = k.transpose(1, 2) + v_bnsh = v.transpose(1, 2) + + if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return { + "query": q.reshape(shape_dict["query"]), + "key": k_bnsh.contiguous(), + "value": v_bnsh.contiguous(), + } + + feeds = {} + if self.use_kv_cache: + feeds.update( + { + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + } + ) + + if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + feeds.update( + { + "query": q.reshape(shape_dict["query"]), + "key": k.reshape(shape_dict["key"]), + "value": v.reshape(shape_dict["value"]), + } + ) + elif self.input_format == InputFormats.QKV_BSN3H: + query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) + key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) + value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) + feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous() + elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: + key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) + value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) + feeds["query"] = q.reshape(shape_dict["query"]) + feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous() + + return feeds + + def get_input_output_names(self): + if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return ["query", "key", "value"], ["output"] + + if 
self.input_format == InputFormats.QKV_BSN3H: + inputs, outputs = ["query"], ["output"] + elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: + inputs, outputs = ["query", "key"], ["output"] + else: + inputs, outputs = ["query", "key", "value"], ["output"] + + if self.use_kv_cache: + return [*inputs, "past_key", "past_value"], [*outputs, "present_key", "present_value"] + else: + return inputs, outputs -def create_multihead_attention_graph(config: Config): - query = helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.num_heads * config.head_size, - ], - ) - - key = helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.num_heads * config.head_size, - ], - ) - - value = helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.num_heads * config.head_size, - ], - ) - - packed_qkv = helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.num_heads, - 3, - config.head_size, - ], - ) - - packed_kv = helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.num_heads, - 2, - config.head_size, - ], - ) - - if config.input_format == InputFormats.QKV_BSN3H: - input_names = ["query"] - inputs = [packed_qkv] - elif config.input_format == InputFormats.Q_KV_BSNH_BSN2H: - input_names = ["query", "key"] - inputs = [query, packed_kv] - else: # input_format==InputFormats.Q_K_V_BSNH - input_names = ["query", "key", "value"] - inputs = [query, key, value] - +def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig): + input_names, output_names = config.get_input_output_names() + float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT nodes = [ helper.make_node( "MultiHeadAttention", 
input_names, - ["output"], + output_names, "MultiHeadAttention_0", num_heads=config.num_heads, + scale=config.softmax_scale, domain="com.microsoft", ), ] + shape_dict = config.shape_dict() + inputs = [ + helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) + for input_name in input_names + ] outputs = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], - ), + helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) + for output_name in output_names ] graph = helper.make_graph( @@ -144,41 +246,46 @@ def create_multihead_attention_graph(config: Config): return model.SerializeToString() -def input_output_shapes(config: Config): - if config.input_format == InputFormats.QKV_BSN3H: - return { - "query": (config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size), - "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), - } - - if config.input_format == InputFormats.Q_KV_BSNH_BSN2H: - return { - "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), - "key": (config.batch_size, config.kv_sequence_length, config.num_heads, 2, config.head_size), - "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), - } - - return { - "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), - "key": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), - "value": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), - "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), - } - - def create_session( - device_id: int, config: Config, provider: str = "CUDAExecutionProvider", enable_cuda_graph: bool = False + config: MultiHeadAttentionConfig, + provider: str = 
"CUDAExecutionProvider", + enable_cuda_graph: bool = False, + device_id: int = 0, ) -> CudaSession: - onnx_model_str = create_multihead_attention_graph(config) - provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) - ort_session = InferenceSession(onnx_model_str, providers=[(provider, provider_options), "CPUExecutionProvider"]) - device = torch.device("cuda", device_id) + onnx_model_str = create_multi_head_attention_onnx_model(config) + + if provider == "CUDAExecutionProvider": + provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + providers = [(provider, provider_options), "CPUExecutionProvider"] + device = torch.device("cuda", device_id) + else: + providers = ["CPUExecutionProvider"] + device = torch.device("cpu") + + ort_session = InferenceSession(onnx_model_str, providers=providers) cuda_session = CudaSession(ort_session, device, enable_cuda_graph) - shape_dict = input_output_shapes(config) + shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) return cuda_session +class OrtMultiHeadAttention: + """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" + + def __init__( + self, + config: MultiHeadAttentionConfig, + provider: str = "CUDAExecutionProvider", + enable_cuda_graph: bool = False, + device_id: int = 0, + ): + self.ort_session = create_session(config, provider, enable_cuda_graph=enable_cuda_graph, device_id=device_id) + self.feed_dict = config.random_inputs() + + def infer(self): + return self.ort_session.infer(self.feed_dict) + + def measure_latency(cuda_session: CudaSession, input_dict): start = time.time() _ = cuda_session.infer(input_dict) @@ -194,7 +301,7 @@ def tflops_per_second(flop, time): return (flop / time / 10**12) if not math.isnan(time) else 0.0 -def get_sm8x_kernel_name(config: Config) -> str: +def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: # This classification is for Nvidia GPU of Compute Capability 8.* like 
A100. # Note that some kernel might not exist in older or newer GPUs. if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": @@ -218,49 +325,72 @@ def get_sm8x_kernel_name(config: Config) -> str: return "Unfused" -def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeats: int = 100): - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) +def get_cpu_kernel_name() -> str: + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "CPU:Flash" + return "CPU:Unfused" - # (batch_size, sequence_length, num_heads, head_size) - configs = [ - (32, 512, 64, 32), - (32, 512, 128, 16), - (16, 1024, 64, 32), - (16, 1024, 128, 16), - (8, 2048, 64, 32), - (8, 2048, 128, 16), - (4, 4096, 64, 32), - (4, 4096, 128, 16), - (2, 8192, 64, 32), - (2, 8192, 128, 16), - (1, 16384, 64, 32), - (1, 16384, 128, 16), - # stable diffusion - (1, 4096, 8, 40), - (1, 4096, 8, 80), - (1, 4096, 8, 160), - (4, 4096, 8, 40), - (4, 4096, 8, 80), - (4, 4096, 8, 160), - (1, 16384, 8, 40), - (1, 16384, 8, 80), - (1, 16384, 8, 160), - # bert-base - (128, 128, 12, 64), - (64, 128, 12, 64), - (128, 384, 12, 64), - (64, 384, 12, 64), - (128, 512, 12, 64), - (64, 512, 12, 64), - # TNLGv4 - (4, 2048, 32, 128), - (4, 4096, 32, 128), - (8, 2048, 32, 128), - (8, 4096, 32, 128), - ] - print(f"enable_cuda_graph={enable_cuda_graph}") +def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): + if use_gpu: + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] + provider = "CUDAExecutionProvider" + print(f"enable_cuda_graph={enable_cuda_graph}") + else: + device_id = 0 + device = torch.device("cpu") + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + enable_cuda_graph = False + provider = "CPUExecutionProvider" + + if use_gpu: + # (batch_size, sequence_length, past_sequence_length, num_heads, 
head_size, run_unfused) + configs = [ + (32, 512, 0, 64, 32, True), + (32, 512, 0, 128, 16, True), + (16, 1024, 0, 64, 32, True), + (16, 1024, 0, 128, 16, True), + (8, 2048, 0, 64, 32, True), + (8, 2048, 0, 128, 16, False), + (4, 4096, 0, 64, 32, False), + (4, 4096, 0, 128, 16, False), + (2, 8192, 0, 64, 32, False), + (2, 8192, 0, 128, 16, False), + (1, 16384, 0, 64, 32, False), + (1, 16384, 0, 128, 16, False), + # stable diffusion + (1, 4096, 0, 8, 40, False), + (1, 4096, 0, 8, 80, False), + (1, 4096, 0, 8, 160, False), + (4, 4096, 0, 8, 40, False), + (4, 4096, 0, 8, 80, False), + (4, 4096, 0, 8, 160, False), + (1, 16384, 0, 8, 40, False), + (1, 16384, 0, 8, 80, False), + (1, 16384, 0, 8, 160, False), + # bert-base + (128, 128, 0, 12, 64, True), + (64, 128, 0, 12, 64, True), + (128, 384, 0, 12, 64, True), + (64, 384, 0, 12, 64, True), + (128, 512, 0, 12, 64, True), + (64, 512, 0, 12, 64, True), + # TNLGv4 + (4, 2048, 0, 32, 128, True), + (4, 4096, 0, 32, 128, False), + (8, 2048, 0, 32, 128, False), + (8, 4096, 0, 32, 128, False), + ] + else: + configs = [ + (1, 128, 0, 32, 128, True), + (1, 256, 0, 32, 128, True), + (1, 512, 0, 32, 128, True), + (1, 1024, 0, 32, 128, True), + (1, 2048, 0, 32, 128, True), + ] # List of environment variables to enable/disable attention kernels print("Environment Variables:") @@ -277,67 +407,176 @@ def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeat value = os.getenv(name) if value is not None: print(f"{name}={value}") - print() - print("format\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") + print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") causal = False - for input_format in [InputFormats.Q_K_V_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]: - for batch_size, sequence_length, num_heads, head_size in configs: - config = Config(batch_size, sequence_length, sequence_length, num_heads, head_size, causal, input_format) - session = create_session(device_id, 
config, enable_cuda_graph=enable_cuda_graph) - - qkv = torch.randn(batch_size, sequence_length, 3, num_heads, head_size, device=device, dtype=dtype) - q, k, v = qkv.unbind(dim=2) - - if input_format == InputFormats.QKV_BSN3H: - if config.sequence_length != config.kv_sequence_length: - continue - q = torch.reshape(q, (-1, config.num_heads, config.head_size)) - k = torch.reshape(k, (-1, config.num_heads, config.head_size)) - v = torch.reshape(v, (-1, config.num_heads, config.head_size)) - packed_qkv = torch.dstack((q, k, v)).reshape( - config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size + for input_format in formats: + for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + for use_kv_cache in [False]: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=True, + use_kv_cache=use_kv_cache, + past_sequence_length=past_sequence_length, + max_cache_sequence_length=None, + kv_sequence_length=None, + device=device, + dtype=torch.float16 if use_gpu else torch.float, + share_past_present_buffer=False, + input_format=input_format, ) - input_dict = {"query": packed_qkv.contiguous()} - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - k = torch.reshape(k, (-1, config.num_heads, config.head_size)) - v = torch.reshape(v, (-1, config.num_heads, config.head_size)) - packed_kv = torch.dstack((k, v)).reshape( - config.batch_size, config.sequence_length, config.num_heads, 2, config.head_size + + session = create_session( + config, provider=provider, enable_cuda_graph=enable_cuda_graph, device_id=device_id ) - input_dict = {"query": q.contiguous(), "key": packed_kv.contiguous()} - else: # input_format == InputFormats.Q_K_V_BSNH - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - k = torch.reshape(k, 
(config.batch_size, config.kv_sequence_length, -1)) - v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) - input_dict = { - "query": q.contiguous(), - "key": k.contiguous(), - "value": v.contiguous(), - } - # warm up session - _ = measure_latency(session, input_dict) + if use_gpu: + kernel = get_gpu_kernel_name(config) + else: + kernel = get_cpu_kernel_name() - latency_list = [] - for _ in range(repeats): - latency = measure_latency(session, input_dict) - latency_list.append(latency) - average_latency = statistics.mean(latency_list) + if kernel == "Unfused": + # Skip large sequence length for Unfused kernel to avoid OOM. + if not enable_unfused: + continue - del session + # Unfused kernel does not support packed QKV or packed KV formats. + if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + continue - # compute TFLOPS per second - speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency) + input_dict = config.random_inputs() - kernel = get_sm8x_kernel_name(config) - format = InputFormats.input_format_str(input_format) - print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" - ) + # warm up session + _ = measure_latency(session, input_dict) + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + # compute TFLOPS per second + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + format = InputFormats.input_format_str(input_format) + print( + f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + ) + + +def plot_prompt_performance( + sm: int, + model_name: str, + batch_size: int, + num_heads: int, + head_size: int, + 
max_seq_len: int, +): + import triton + + formats = InputFormats.get_name_list() + + # Exclude cross attention since kernel crashes for some configuration. + formats = formats[:-1] + + settings = { + "line_vals": formats, + "line_names": ["ORT-MHA:" + name for name in formats], + "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)], + } + + configs = [ + triton.testing.Benchmark( + x_names=["sequence_length"], + x_vals=[2**i for i in range(6, 17) if 2**i <= max_seq_len], + line_arg="input_format", + ylabel="ms", + **settings, + plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{head_size}-fp16", + args={ + "batch_size": batch_size, + "num_heads": num_heads, + "head_size": head_size, + }, + ) + ] + + @triton.testing.perf_report(configs) + def benchmark( + input_format: str, + sequence_length: int, + batch_size: int, + num_heads: int, + head_size: int, + device="cuda", + ): + warmup = 15 + repeat = 100 + + config: MultiHeadAttentionConfig = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=True, + past_sequence_length=0, + kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, + max_cache_sequence_length=max_seq_len, + device=device, + use_kv_cache=False, + input_format=InputFormats.convert(input_format), + ) + + obj = OrtMultiHeadAttention(config, enable_cuda_graph=False) + ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat) + return ms + + benchmark.run(save_path=".", print_data=True) + + +def run_performance_test(sm: int): + """ + Run performance tests for prompt and token generation. 
+ + """ + configures = [ + (1, 32, 128, 8192, "TNLGv4"), + (4, 32, 128, 8192, "TNLGv4"), + (1, 12, 64, 1024, "BertBase"), + (16, 12, 64, 1024, "BertBase"), + (1, 16, 64, 1024, "BertLarge"), + (8, 16, 64, 1024, "BertLarge"), + ] + + for batch_size, num_heads, head_size, max_seq_len, model_name in configures: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + head_size=head_size, + max_seq_len=max_seq_len, + model_name=model_name, + ) if __name__ == "__main__": - run_tflops_test(enable_cuda_graph=False) + if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): + # Test CUDA provider + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + s = torch.cuda.Stream() + with torch.cuda.stream(s), torch.no_grad(): + run_performance_test(sm) + + run_tflops_test(use_gpu=True, enable_cuda_graph=True) + + # Test CPU provider + run_tflops_test(use_gpu=False, enable_cuda_graph=False)