mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Update MultiHeadAttention benchmark to test CPU (#20972)
### Description MultiHeadAttention benchmark script only supports cuda provider right now. This updates the script to support testing cpu operator and ploting gpu latency. ### Motivation and Context Benchmark for the coming cpu flash attention.
This commit is contained in:
parent
99f0fe3fae
commit
a2b0a69dcc
1 changed files with 449 additions and 210 deletions
|
|
@ -12,125 +12,227 @@ import math
|
|||
import os
|
||||
import statistics
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from onnxruntime import InferenceSession, get_available_providers
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession
|
||||
|
||||
|
||||
class InputFormats:
|
||||
Q_K_V_BSNH = 0
|
||||
Q_K_V_BSNH_BSNH_BSNH = 0
|
||||
QKV_BSN3H = 1
|
||||
Q_KV_BSNH_BSN2H = 2
|
||||
Q_K_V_BSNH_BNSH_BNSH = 3 # For cross attention
|
||||
|
||||
@staticmethod
|
||||
def input_format_str(format: int) -> str:
|
||||
return "QKV" if format == 1 else "Q,KV" if format == 2 else "Q,K,V"
|
||||
names = InputFormats.get_name_list()
|
||||
return names[format]
|
||||
|
||||
@staticmethod
|
||||
def convert(format_str: str) -> int:
|
||||
names = InputFormats.get_name_list()
|
||||
return names.index(format_str)
|
||||
|
||||
@staticmethod
|
||||
def get_name_list() -> List[str]:
|
||||
return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"]
|
||||
|
||||
|
||||
class Config:
|
||||
batch_size: int = 0
|
||||
sequence_length: int = 0
|
||||
kv_sequence_length: int = 0
|
||||
num_heads: int = 0
|
||||
head_size: int = 0
|
||||
causal: bool = False
|
||||
input_format: int = InputFormats.Q_K_V_BSNH
|
||||
|
||||
def __init__(self, b, s, s2, n, h, causal, input_format):
|
||||
self.batch_size = b
|
||||
self.sequence_length = s
|
||||
self.kv_sequence_length = s2
|
||||
self.num_heads = n
|
||||
self.head_size = h
|
||||
class MultiHeadAttentionConfig:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
causal: bool,
|
||||
past_sequence_length: int = 0,
|
||||
kv_sequence_length=None,
|
||||
max_cache_sequence_length=None,
|
||||
softmax_scale: float = 0.0,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
use_kv_cache: bool = False,
|
||||
share_past_present_buffer: bool = False,
|
||||
input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH,
|
||||
):
|
||||
self.operator = "MultiHeadAttention"
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.kv_sequence_length = kv_sequence_length or sequence_length
|
||||
self.max_cache_sequence_length = max_cache_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5))
|
||||
|
||||
self.use_kv_cache = use_kv_cache
|
||||
if not use_kv_cache:
|
||||
assert past_sequence_length == 0
|
||||
|
||||
if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
# cross attention does not have past state
|
||||
assert not use_kv_cache
|
||||
|
||||
# Derived values
|
||||
self.total_sequence_length = self.kv_sequence_length + past_sequence_length
|
||||
self.past_buffer_length = self.max_cache_sequence_length if share_past_present_buffer else past_sequence_length
|
||||
self.present_buffer_length = (
|
||||
self.max_cache_sequence_length if share_past_present_buffer else self.total_sequence_length
|
||||
)
|
||||
|
||||
self.device = device
|
||||
self.share_past_present_buffer = share_past_present_buffer
|
||||
self.input_format = input_format
|
||||
self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H
|
||||
self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H
|
||||
self.dtype = dtype
|
||||
|
||||
def shape_dict(self, input_format=None):
|
||||
input_format = input_format or self.input_format
|
||||
if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
# cross attention does not have past state
|
||||
return {
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
|
||||
"value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
}
|
||||
|
||||
if self.use_kv_cache:
|
||||
shapes = {
|
||||
"past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
|
||||
"past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
|
||||
"present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
|
||||
}
|
||||
else:
|
||||
shapes = {
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
}
|
||||
|
||||
if input_format == InputFormats.QKV_BSN3H:
|
||||
shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)})
|
||||
elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
shapes.update(
|
||||
{
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size),
|
||||
}
|
||||
)
|
||||
else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH
|
||||
shapes.update(
|
||||
{
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
}
|
||||
)
|
||||
return shapes
|
||||
|
||||
def random_inputs(self, seed: int = 123):
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
|
||||
shape_dict = self.shape_dict()
|
||||
|
||||
if seed > 0:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
shape = (self.batch_size, self.sequence_length, self.num_heads, self.head_size)
|
||||
q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
||||
k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
||||
v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
||||
k_bnsh = k.transpose(1, 2)
|
||||
v_bnsh = v.transpose(1, 2)
|
||||
|
||||
if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
return {
|
||||
"query": q.reshape(shape_dict["query"]),
|
||||
"key": k_bnsh.contiguous(),
|
||||
"value": v_bnsh.contiguous(),
|
||||
}
|
||||
|
||||
feeds = {}
|
||||
if self.use_kv_cache:
|
||||
feeds.update(
|
||||
{
|
||||
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(
|
||||
mean=0, std=0.1
|
||||
),
|
||||
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(
|
||||
mean=0, std=0.1
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
|
||||
feeds.update(
|
||||
{
|
||||
"query": q.reshape(shape_dict["query"]),
|
||||
"key": k.reshape(shape_dict["key"]),
|
||||
"value": v.reshape(shape_dict["value"]),
|
||||
}
|
||||
)
|
||||
elif self.input_format == InputFormats.QKV_BSN3H:
|
||||
query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
|
||||
key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
|
||||
value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
|
||||
feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous()
|
||||
elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
|
||||
value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
|
||||
feeds["query"] = q.reshape(shape_dict["query"])
|
||||
feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous()
|
||||
|
||||
return feeds
|
||||
|
||||
def get_input_output_names(self):
|
||||
if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
return ["query", "key"], ["output"]
|
||||
|
||||
if self.input_format == InputFormats.QKV_BSN3H:
|
||||
inputs, outputs = ["query"], ["output"]
|
||||
elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
inputs, outputs = ["query", "key"], ["output"]
|
||||
else:
|
||||
inputs, outputs = ["query", "key", "value"], ["output"]
|
||||
|
||||
if self.use_kv_cache:
|
||||
return [*input, "past_key", "past_value"], [*outputs, "present_key", "present_value"]
|
||||
else:
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def create_multihead_attention_graph(config: Config):
|
||||
query = helper.make_tensor_value_info(
|
||||
"query",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.sequence_length,
|
||||
config.num_heads * config.head_size,
|
||||
],
|
||||
)
|
||||
|
||||
key = helper.make_tensor_value_info(
|
||||
"key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.kv_sequence_length,
|
||||
config.num_heads * config.head_size,
|
||||
],
|
||||
)
|
||||
|
||||
value = helper.make_tensor_value_info(
|
||||
"value",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.kv_sequence_length,
|
||||
config.num_heads * config.head_size,
|
||||
],
|
||||
)
|
||||
|
||||
packed_qkv = helper.make_tensor_value_info(
|
||||
"query",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.sequence_length,
|
||||
config.num_heads,
|
||||
3,
|
||||
config.head_size,
|
||||
],
|
||||
)
|
||||
|
||||
packed_kv = helper.make_tensor_value_info(
|
||||
"key",
|
||||
TensorProto.FLOAT16,
|
||||
[
|
||||
config.batch_size,
|
||||
config.kv_sequence_length,
|
||||
config.num_heads,
|
||||
2,
|
||||
config.head_size,
|
||||
],
|
||||
)
|
||||
|
||||
if config.input_format == InputFormats.QKV_BSN3H:
|
||||
input_names = ["query"]
|
||||
inputs = [packed_qkv]
|
||||
elif config.input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
input_names = ["query", "key"]
|
||||
inputs = [query, packed_kv]
|
||||
else: # input_format==InputFormats.Q_K_V_BSNH
|
||||
input_names = ["query", "key", "value"]
|
||||
inputs = [query, key, value]
|
||||
|
||||
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
|
||||
input_names, output_names = config.get_input_output_names()
|
||||
float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"MultiHeadAttention",
|
||||
input_names,
|
||||
["output"],
|
||||
output_names,
|
||||
"MultiHeadAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
scale=config.softmax_scale,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name]))
|
||||
for input_name in input_names
|
||||
]
|
||||
outputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"output",
|
||||
TensorProto.FLOAT16,
|
||||
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
|
||||
),
|
||||
helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name]))
|
||||
for output_name in output_names
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -144,41 +246,46 @@ def create_multihead_attention_graph(config: Config):
|
|||
return model.SerializeToString()
|
||||
|
||||
|
||||
def input_output_shapes(config: Config):
|
||||
if config.input_format == InputFormats.QKV_BSN3H:
|
||||
return {
|
||||
"query": (config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size),
|
||||
"output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size),
|
||||
}
|
||||
|
||||
if config.input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
return {
|
||||
"query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size),
|
||||
"key": (config.batch_size, config.kv_sequence_length, config.num_heads, 2, config.head_size),
|
||||
"output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size),
|
||||
}
|
||||
|
||||
return {
|
||||
"query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size),
|
||||
"key": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size),
|
||||
"value": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size),
|
||||
"output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size),
|
||||
}
|
||||
|
||||
|
||||
def create_session(
|
||||
device_id: int, config: Config, provider: str = "CUDAExecutionProvider", enable_cuda_graph: bool = False
|
||||
config: MultiHeadAttentionConfig,
|
||||
provider: str = "CUDAExecutionProvider",
|
||||
enable_cuda_graph: bool = False,
|
||||
device_id: int = 0,
|
||||
) -> CudaSession:
|
||||
onnx_model_str = create_multihead_attention_graph(config)
|
||||
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
|
||||
ort_session = InferenceSession(onnx_model_str, providers=[(provider, provider_options), "CPUExecutionProvider"])
|
||||
device = torch.device("cuda", device_id)
|
||||
onnx_model_str = create_multi_head_attention_onnx_model(config)
|
||||
|
||||
if provider == "CUDAExecutionProvider":
|
||||
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
|
||||
providers = [(provider, provider_options), "CPUExecutionProvider"]
|
||||
device = torch.device("cuda", device_id)
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
device = torch.device("cpu")
|
||||
|
||||
ort_session = InferenceSession(onnx_model_str, providers=providers)
|
||||
cuda_session = CudaSession(ort_session, device, enable_cuda_graph)
|
||||
shape_dict = input_output_shapes(config)
|
||||
shape_dict = config.shape_dict()
|
||||
cuda_session.allocate_buffers(shape_dict)
|
||||
return cuda_session
|
||||
|
||||
|
||||
class OrtMultiHeadAttention:
|
||||
"""A wrapper of ORT MultiHeadAttention to test relevance and performance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultiHeadAttentionConfig,
|
||||
provider: str = "CUDAExecutionProvider",
|
||||
enable_cuda_graph: bool = False,
|
||||
device_id: int = 0,
|
||||
):
|
||||
self.ort_session = create_session(config, provider, enable_cuda_graph=enable_cuda_graph, device_id=device_id)
|
||||
self.feed_dict = config.random_inputs()
|
||||
|
||||
def infer(self):
|
||||
return self.ort_session.infer(self.feed_dict)
|
||||
|
||||
|
||||
def measure_latency(cuda_session: CudaSession, input_dict):
|
||||
start = time.time()
|
||||
_ = cuda_session.infer(input_dict)
|
||||
|
|
@ -194,7 +301,7 @@ def tflops_per_second(flop, time):
|
|||
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
||||
|
||||
|
||||
def get_sm8x_kernel_name(config: Config) -> str:
|
||||
def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str:
|
||||
# This classification is for Nvidia GPU of Compute Capability 8.* like A100.
|
||||
# Note that some kernel might not exist in older or newer GPUs.
|
||||
if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1":
|
||||
|
|
@ -218,49 +325,72 @@ def get_sm8x_kernel_name(config: Config) -> str:
|
|||
return "Unfused"
|
||||
|
||||
|
||||
def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeats: int = 100):
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
def get_cpu_kernel_name() -> str:
|
||||
if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1":
|
||||
return "CPU:Flash"
|
||||
return "CPU:Unfused"
|
||||
|
||||
# (batch_size, sequence_length, num_heads, head_size)
|
||||
configs = [
|
||||
(32, 512, 64, 32),
|
||||
(32, 512, 128, 16),
|
||||
(16, 1024, 64, 32),
|
||||
(16, 1024, 128, 16),
|
||||
(8, 2048, 64, 32),
|
||||
(8, 2048, 128, 16),
|
||||
(4, 4096, 64, 32),
|
||||
(4, 4096, 128, 16),
|
||||
(2, 8192, 64, 32),
|
||||
(2, 8192, 128, 16),
|
||||
(1, 16384, 64, 32),
|
||||
(1, 16384, 128, 16),
|
||||
# stable diffusion
|
||||
(1, 4096, 8, 40),
|
||||
(1, 4096, 8, 80),
|
||||
(1, 4096, 8, 160),
|
||||
(4, 4096, 8, 40),
|
||||
(4, 4096, 8, 80),
|
||||
(4, 4096, 8, 160),
|
||||
(1, 16384, 8, 40),
|
||||
(1, 16384, 8, 80),
|
||||
(1, 16384, 8, 160),
|
||||
# bert-base
|
||||
(128, 128, 12, 64),
|
||||
(64, 128, 12, 64),
|
||||
(128, 384, 12, 64),
|
||||
(64, 384, 12, 64),
|
||||
(128, 512, 12, 64),
|
||||
(64, 512, 12, 64),
|
||||
# TNLGv4
|
||||
(4, 2048, 32, 128),
|
||||
(4, 4096, 32, 128),
|
||||
(8, 2048, 32, 128),
|
||||
(8, 4096, 32, 128),
|
||||
]
|
||||
|
||||
print(f"enable_cuda_graph={enable_cuda_graph}")
|
||||
def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100):
|
||||
if use_gpu:
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]
|
||||
provider = "CUDAExecutionProvider"
|
||||
print(f"enable_cuda_graph={enable_cuda_graph}")
|
||||
else:
|
||||
device_id = 0
|
||||
device = torch.device("cpu")
|
||||
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
|
||||
enable_cuda_graph = False
|
||||
provider = "CPUExecutionProvider"
|
||||
|
||||
if use_gpu:
|
||||
# (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused)
|
||||
configs = [
|
||||
(32, 512, 0, 64, 32, True),
|
||||
(32, 512, 0, 128, 16, True),
|
||||
(16, 1024, 0, 64, 32, True),
|
||||
(16, 1024, 0, 128, 16, True),
|
||||
(8, 2048, 0, 64, 32, True),
|
||||
(8, 2048, 0, 128, 16, False),
|
||||
(4, 4096, 0, 64, 32, False),
|
||||
(4, 4096, 0, 128, 16, False),
|
||||
(2, 8192, 0, 64, 32, False),
|
||||
(2, 8192, 0, 128, 16, False),
|
||||
(1, 16384, 0, 64, 32, False),
|
||||
(1, 16384, 0, 128, 16, False),
|
||||
# stable diffusion
|
||||
(1, 4096, 0, 8, 40, False),
|
||||
(1, 4096, 0, 8, 80, False),
|
||||
(1, 4096, 0, 8, 160, False),
|
||||
(4, 4096, 0, 8, 40, False),
|
||||
(4, 4096, 0, 8, 80, False),
|
||||
(4, 4096, 0, 8, 160, False),
|
||||
(1, 16384, 0, 8, 40, False),
|
||||
(1, 16384, 0, 8, 80, False),
|
||||
(1, 16384, 0, 8, 160, False),
|
||||
# bert-base
|
||||
(128, 128, 0, 12, 64, True),
|
||||
(64, 128, 0, 12, 64, True),
|
||||
(128, 384, 0, 12, 64, True),
|
||||
(64, 384, 0, 12, 64, True),
|
||||
(128, 512, 0, 12, 64, True),
|
||||
(64, 512, 0, 12, 64, True),
|
||||
# TNLGv4
|
||||
(4, 2048, 0, 32, 128, True),
|
||||
(4, 4096, 0, 32, 128, False),
|
||||
(8, 2048, 0, 32, 128, False),
|
||||
(8, 4096, 0, 32, 128, False),
|
||||
]
|
||||
else:
|
||||
configs = [
|
||||
(1, 128, 0, 32, 128, True),
|
||||
(1, 256, 0, 32, 128, True),
|
||||
(1, 512, 0, 32, 128, True),
|
||||
(1, 1024, 0, 32, 128, True),
|
||||
(1, 2048, 0, 32, 128, True),
|
||||
]
|
||||
|
||||
# List of environment variables to enable/disable attention kernels
|
||||
print("Environment Variables:")
|
||||
|
|
@ -277,67 +407,176 @@ def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeat
|
|||
value = os.getenv(name)
|
||||
if value is not None:
|
||||
print(f"{name}={value}")
|
||||
print()
|
||||
|
||||
print("format\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel")
|
||||
print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel")
|
||||
causal = False
|
||||
for input_format in [InputFormats.Q_K_V_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]:
|
||||
for batch_size, sequence_length, num_heads, head_size in configs:
|
||||
config = Config(batch_size, sequence_length, sequence_length, num_heads, head_size, causal, input_format)
|
||||
|
||||
session = create_session(device_id, config, enable_cuda_graph=enable_cuda_graph)
|
||||
|
||||
qkv = torch.randn(batch_size, sequence_length, 3, num_heads, head_size, device=device, dtype=dtype)
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
|
||||
if input_format == InputFormats.QKV_BSN3H:
|
||||
if config.sequence_length != config.kv_sequence_length:
|
||||
continue
|
||||
q = torch.reshape(q, (-1, config.num_heads, config.head_size))
|
||||
k = torch.reshape(k, (-1, config.num_heads, config.head_size))
|
||||
v = torch.reshape(v, (-1, config.num_heads, config.head_size))
|
||||
packed_qkv = torch.dstack((q, k, v)).reshape(
|
||||
config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size
|
||||
for input_format in formats:
|
||||
for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
|
||||
for use_kv_cache in [False]:
|
||||
config = MultiHeadAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
causal=True,
|
||||
use_kv_cache=use_kv_cache,
|
||||
past_sequence_length=past_sequence_length,
|
||||
max_cache_sequence_length=None,
|
||||
kv_sequence_length=None,
|
||||
device=device,
|
||||
dtype=torch.float16 if use_gpu else torch.float,
|
||||
share_past_present_buffer=False,
|
||||
input_format=input_format,
|
||||
)
|
||||
input_dict = {"query": packed_qkv.contiguous()}
|
||||
elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
|
||||
k = torch.reshape(k, (-1, config.num_heads, config.head_size))
|
||||
v = torch.reshape(v, (-1, config.num_heads, config.head_size))
|
||||
packed_kv = torch.dstack((k, v)).reshape(
|
||||
config.batch_size, config.sequence_length, config.num_heads, 2, config.head_size
|
||||
|
||||
session = create_session(
|
||||
config, provider=provider, enable_cuda_graph=enable_cuda_graph, device_id=device_id
|
||||
)
|
||||
input_dict = {"query": q.contiguous(), "key": packed_kv.contiguous()}
|
||||
else: # input_format == InputFormats.Q_K_V_BSNH
|
||||
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
|
||||
k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1))
|
||||
v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1))
|
||||
input_dict = {
|
||||
"query": q.contiguous(),
|
||||
"key": k.contiguous(),
|
||||
"value": v.contiguous(),
|
||||
}
|
||||
|
||||
# warm up session
|
||||
_ = measure_latency(session, input_dict)
|
||||
if use_gpu:
|
||||
kernel = get_gpu_kernel_name(config)
|
||||
else:
|
||||
kernel = get_cpu_kernel_name()
|
||||
|
||||
latency_list = []
|
||||
for _ in range(repeats):
|
||||
latency = measure_latency(session, input_dict)
|
||||
latency_list.append(latency)
|
||||
average_latency = statistics.mean(latency_list)
|
||||
if kernel == "Unfused":
|
||||
# Skip large sequence length for Unfused kernel to avoid OOM.
|
||||
if not enable_unfused:
|
||||
continue
|
||||
|
||||
del session
|
||||
# Unfused kernel does not support packed QKV or packed KV formats.
|
||||
if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]:
|
||||
continue
|
||||
|
||||
# compute TFLOPS per second
|
||||
speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency)
|
||||
input_dict = config.random_inputs()
|
||||
|
||||
kernel = get_sm8x_kernel_name(config)
|
||||
format = InputFormats.input_format_str(input_format)
|
||||
print(
|
||||
f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}"
|
||||
)
|
||||
# warm up session
|
||||
_ = measure_latency(session, input_dict)
|
||||
|
||||
latency_list = []
|
||||
for _ in range(repeats):
|
||||
latency = measure_latency(session, input_dict)
|
||||
latency_list.append(latency)
|
||||
average_latency = statistics.mean(latency_list)
|
||||
|
||||
del session
|
||||
|
||||
# compute TFLOPS per second
|
||||
speed = tflops_per_second(
|
||||
flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency
|
||||
)
|
||||
|
||||
format = InputFormats.input_format_str(input_format)
|
||||
print(
|
||||
f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}"
|
||||
)
|
||||
|
||||
|
||||
def plot_prompt_performance(
|
||||
sm: int,
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
):
|
||||
import triton
|
||||
|
||||
formats = InputFormats.get_name_list()
|
||||
|
||||
# Exclude cross attention since kernel crashes for some configuration.
|
||||
formats = formats[:-1]
|
||||
|
||||
settings = {
|
||||
"line_vals": formats,
|
||||
"line_names": ["ORT-MHA:" + name for name in formats],
|
||||
"styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)],
|
||||
}
|
||||
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["sequence_length"],
|
||||
x_vals=[2**i for i in range(6, 17) if 2**i <= max_seq_len],
|
||||
line_arg="input_format",
|
||||
ylabel="ms",
|
||||
**settings,
|
||||
plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{head_size}-fp16",
|
||||
args={
|
||||
"batch_size": batch_size,
|
||||
"num_heads": num_heads,
|
||||
"head_size": head_size,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(
|
||||
input_format: str,
|
||||
sequence_length: int,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
device="cuda",
|
||||
):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
config: MultiHeadAttentionConfig = MultiHeadAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
causal=True,
|
||||
past_sequence_length=0,
|
||||
kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None,
|
||||
max_cache_sequence_length=max_seq_len,
|
||||
device=device,
|
||||
use_kv_cache=False,
|
||||
input_format=InputFormats.convert(input_format),
|
||||
)
|
||||
|
||||
obj = OrtMultiHeadAttention(config, enable_cuda_graph=False)
|
||||
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
|
||||
return ms
|
||||
|
||||
benchmark.run(save_path=".", print_data=True)
|
||||
|
||||
|
||||
def run_performance_test(sm: int):
|
||||
"""
|
||||
Run performance tests for prompt and token generation.
|
||||
|
||||
"""
|
||||
configures = [
|
||||
(1, 32, 128, 8192, "TNLGv4"),
|
||||
(4, 32, 128, 8192, "TNLGv4"),
|
||||
(1, 12, 64, 1024, "BertBase"),
|
||||
(16, 12, 64, 1024, "BertBase"),
|
||||
(1, 16, 64, 1024, "BertLarge"),
|
||||
(8, 16, 64, 1024, "BertLarge"),
|
||||
]
|
||||
|
||||
for batch_size, num_heads, head_size, max_seq_len, model_name in configures:
|
||||
plot_prompt_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tflops_test(enable_cuda_graph=False)
|
||||
if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers():
|
||||
# Test CUDA provider
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
sm = major * 10 + minor
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s), torch.no_grad():
|
||||
run_performance_test(sm)
|
||||
|
||||
run_tflops_test(use_gpu=True, enable_cuda_graph=True)
|
||||
|
||||
# Test CPU provider
|
||||
run_tflops_test(use_gpu=False, enable_cuda_graph=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue