mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
[CUDA] Update benchmark_mha.py to capture debug info to identify sdpa kernel (#21804)
Use debug info to identify the sdpa kernel actually used, and show it in the output of benchmark_mha.py. This updated benchmark script was used to get the benchmark results in https://github.com/microsoft/onnxruntime/pull/21629. (1) Change the output format of the debug info to print text like SdpaKernel=* (2) Add a step to capture stdout from the onnxruntime session, and use a regular expression to parse SdpaKernel=* from the captured text. Other minor changes: (1) Set different default repeats for the benchmark: 100 for CPU and 10000 for CUDA. (2) Fix PrintTensorByDims used in the console dumper: if the dumper is not enabled, do not dump the tensor. (3) Update some comments. ### Motivation and Context Sometimes a fallback is used for an sdpa_kernel. This could confuse users unless we can tell exactly which kernel is used in the benchmark.
This commit is contained in:
parent
44a3923ba5
commit
25d7a4fa08
7 changed files with 121 additions and 54 deletions
|
|
@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
|
|||
const char* name,
|
||||
const T* tensor,
|
||||
gsl::span<const int64_t>& dims) {
|
||||
if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
|
||||
if (!dumper->IsEnabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ((tensor == nullptr || dims.size() == 0)) {
|
||||
std::cout << std::string(name) << " is None" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,45 +128,23 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
|
|||
sstream << " DataType=fp32";
|
||||
}
|
||||
|
||||
sstream << " SdpaKernel=";
|
||||
if (use_flash_attention.has_value() && use_flash_attention.value()) {
|
||||
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
|
||||
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
|
||||
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
|
||||
}
|
||||
|
||||
if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
|
||||
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
|
||||
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
|
||||
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
|
||||
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
|
||||
}
|
||||
|
||||
bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
|
||||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
|
||||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
|
||||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
|
||||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
|
||||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
|
||||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());
|
||||
|
||||
// Fall back to unfused when no fused kernel is enabled.
|
||||
if (!use_fused) {
|
||||
sstream << " MATH=1";
|
||||
sstream << "FLASH_ATTENTION";
|
||||
} else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
|
||||
sstream << "EFFICIENT_ATTENTION";
|
||||
} else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
|
||||
sstream << "TRT_FUSED_ATTENTION";
|
||||
} else if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
|
||||
sstream << "CUDNN_FLASH_ATTENTION";
|
||||
} else if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
|
||||
sstream << "TRT_FLASH_ATTENTION";
|
||||
} else if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
|
||||
sstream << "TRT_CROSS_ATTENTION";
|
||||
} else if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
|
||||
sstream << "TRT_CAUSAL_ATTENTION";
|
||||
} else {
|
||||
sstream << "MATH";
|
||||
}
|
||||
|
||||
// Output text in Cyan color to make it easier to spot.
|
||||
|
|
|
|||
|
|
@ -314,7 +314,7 @@ struct BytesHash {
|
|||
};
|
||||
|
||||
// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
|
||||
// TODO(tianleiwu): since we the key includes sequence lengths, we may want to limit the cache size.
|
||||
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
|
||||
thread_local
|
||||
std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;
|
||||
|
||||
|
|
|
|||
|
|
@ -233,7 +233,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
bool use_fused_runner =
|
||||
kernel_type == AttentionKernelType::AttentionKernel_Default &&
|
||||
!disable_fused_self_attention_ &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
nullptr == attention_bias &&
|
||||
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
|
||||
nullptr == past_key && nullptr == present_key &&
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
|
|||
// Abbreviation and Meanings:
|
||||
// T: token_count
|
||||
// B: batch_size
|
||||
// S: sequence_length (input sequence length of query)
|
||||
// S: sequence_length
|
||||
// N: num_heads
|
||||
// H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
|
||||
// H_v: v_head_size
|
||||
|
|
@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
|
|||
// bias (Q/K/V) : (D + D + D_v)
|
||||
// token_offset : (B, S)
|
||||
// cu_seq_len_shape : (B + 1)
|
||||
// attention_bias : (B, N, S, S), (1, N, S, S) or NULL
|
||||
// attention_bias : (B or 1, N or 1, S, S) or NULL
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
if (input_dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
|
|||
// Input 'value': None
|
||||
// Input 'token_offset': (batch_size, sequence_length)
|
||||
// Input 'cumulative_sequence_length': (batch_size + 1)
|
||||
// Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
|
||||
// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
|
||||
// Output 'output': (token_count, v_hidden_size)
|
||||
|
||||
const auto& query_dims = query_shape.GetDims();
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ import csv
|
|||
import math
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import statistics
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
|
|
@ -771,6 +774,72 @@ def get_compute_capability():
|
|||
return sm
|
||||
|
||||
|
||||
class CaptureStdout:
    """Context manager that captures bytes written to the stdout file descriptor.

    The redirection happens at the file-descriptor level (``os.dup2``), so it
    also sees output written by native code that shares the descriptor, not
    only Python-level ``print`` calls. After the ``with`` block exits, the
    captured bytes are available in ``self.output``.
    """

    def __init__(self):
        # File descriptor behind sys.stdout; this is the descriptor that gets redirected.
        self.fd = sys.stdout.fileno()
        # Number of bytes requested per os.read call while draining the pipe.
        self.chunk_size = 1024
        # Captured bytes; filled in by the reader thread once the pipe reaches EOF.
        self.output = b""

    def _capture(self):
        """Drain the pipe until EOF and store the concatenated bytes in ``self.output``."""
        chunks = []
        # os.read returns b"" at EOF, i.e. once every write end of the pipe is closed.
        while chunk := os.read(self._pipe_reader, self.chunk_size):
            chunks.append(chunk)
        self.output = b"".join(chunks)

    def __enter__(self):
        # Flush text that Python buffered before the redirection so that it goes
        # to the real stdout instead of leaking into the capture.
        sys.stdout.flush()

        # Keep a duplicate of the original descriptor so it can be restored on exit.
        self._duped_fd = os.dup(self.fd)
        self._pipe_reader, pipe_writer = os.pipe()
        os.dup2(pipe_writer, self.fd)
        # self.fd is now the only write end of the pipe; closing it later signals EOF.
        os.close(pipe_writer)

        # Read on a separate thread so that writers do not block when the pipe buffer fills up.
        self._capture_thread = threading.Thread(target=self._capture)
        self._capture_thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Push text still sitting in Python's stdout buffer into the pipe;
            # otherwise it would be emitted after the descriptor is restored and be missed.
            sys.stdout.flush()
        finally:
            # Closing the last write end makes the reader thread see EOF and finish.
            os.close(self.fd)
            self._capture_thread.join()
            os.close(self._pipe_reader)
            # Restore the original stdout descriptor.
            os.dup2(self._duped_fd, self.fd)
            os.close(self._duped_fd)
|
||||
|
||||
|
||||
def sdpa_kernel_from_debug_info(
    config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
):
    """Run one inference with attention-kernel debug info enabled and report the kernel used.

    The environment variable ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO is set to "1"
    while a session is created and run once with random inputs. Stdout is
    captured during that run and searched for the first ``SdpaKernel=<NAME>`` token.

    Args:
        config: test configuration used to create the session and its random inputs.
        attention_kernel: the kernel requested when creating the session.
        sess_options: session options passed through to ``create_session``.

    Returns:
        The benchmark's short kernel label (for example "ort:flash"), or None when
        the run fails, no ``SdpaKernel=`` token is found, or the reported name is
        not one of the known kernels.
    """
    # Maps the name printed in the debug info to the label used in benchmark output.
    kernel_names = {
        "FLASH_ATTENTION": "ort:flash",
        "EFFICIENT_ATTENTION": "ort:efficient",
        "CUDNN_FLASH_ATTENTION": "ort:cudnn",
        "MATH": "ort:math",
        "TRT_FUSED_ATTENTION": "ort:trt_fmha",
        "TRT_FLASH_ATTENTION": "ort:trt_flash",
        "TRT_CROSS_ATTENTION": "ort:trt_cross",
        "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
    }

    os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
    captured_text = None
    try:
        with CaptureStdout() as captured:
            session = create_session(config, sess_options, attention_kernel=attention_kernel)
            input_dict = config.random_inputs()
            session.infer(input_dict)
        # Captured bytes may include arbitrary native output; do not fail on undecodable bytes.
        captured_text = captured.output.decode(errors="replace")
    except Exception as e:
        print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
    finally:
        # Turn the debug output off again so that it does not affect later runs.
        os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

    if captured_text is None:
        return None

    m = re.search(r"SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
    if m is None:
        print("Failed to get sdpa kernel from debug info:", captured_text)
        return None

    name = m.group("kernel")
    label = kernel_names.get(name)
    if label is None:
        # An unrecognized name is reported as None so the caller can skip this
        # configuration instead of the whole benchmark aborting with a KeyError.
        print("Unknown sdpa kernel name in debug info:", name)
    return label
|
||||
|
||||
|
||||
def run_tflops_test(
|
||||
csv_writer: csv.DictWriter,
|
||||
args: argparse.Namespace,
|
||||
|
|
@ -809,7 +878,9 @@ def run_tflops_test(
|
|||
backends = [SdpaKernel.DEFAULT]
|
||||
|
||||
configs = get_test_configs(args)
|
||||
print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel")
|
||||
print(
|
||||
"\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
|
||||
)
|
||||
|
||||
for input_format in formats:
|
||||
for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
|
||||
|
|
@ -836,14 +907,13 @@ def run_tflops_test(
|
|||
for attention_kernel in backends:
|
||||
sess_options = SessionOptions()
|
||||
sess_options.intra_op_num_threads = intra_op_num_threads
|
||||
session = create_session(config, sess_options, attention_kernel=attention_kernel)
|
||||
|
||||
if use_gpu:
|
||||
kernel = get_gpu_kernel_name(attention_kernel)
|
||||
request_kernel = get_gpu_kernel_name(attention_kernel)
|
||||
else:
|
||||
kernel = get_cpu_kernel_name(config)
|
||||
request_kernel = get_cpu_kernel_name(config)
|
||||
|
||||
if "math" in kernel:
|
||||
if "math" in request_kernel:
|
||||
# Skip large sequence length for Unfused kernel to avoid OOM.
|
||||
if not enable_unfused:
|
||||
if config.verbose:
|
||||
|
|
@ -856,13 +926,23 @@ def run_tflops_test(
|
|||
print(f"skip input_format for {vars(config)}")
|
||||
continue
|
||||
|
||||
if use_gpu:
|
||||
actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
|
||||
if actual_kernel is None:
|
||||
print(f"Warning: skip {config} since kernel from debug info is None")
|
||||
continue
|
||||
else:
|
||||
# CPU has no debug info for now.
|
||||
actual_kernel = request_kernel
|
||||
|
||||
session = create_session(config, sess_options, attention_kernel=attention_kernel)
|
||||
input_dict = config.random_inputs()
|
||||
|
||||
# warm up session
|
||||
try:
|
||||
_ = measure_latency(session, input_dict)
|
||||
except Exception as e:
|
||||
print(f"Failed to run {kernel=} for {config=}. Exception: {e}")
|
||||
print(f"Failed to run {request_kernel=} for {config=}. Exception: {e}")
|
||||
continue
|
||||
|
||||
latency_list = []
|
||||
|
|
@ -898,7 +978,8 @@ def run_tflops_test(
|
|||
"intra_op_num_threads": intra_op_num_threads,
|
||||
"average_latency": average_latency,
|
||||
"tflops": speed,
|
||||
"kernel": kernel,
|
||||
"request_kernel": request_kernel,
|
||||
"kernel": actual_kernel,
|
||||
}
|
||||
csv_writer.writerow(row)
|
||||
|
||||
|
|
@ -906,7 +987,7 @@ def run_tflops_test(
|
|||
print(
|
||||
f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t"
|
||||
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
|
||||
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}"
|
||||
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -979,7 +1060,7 @@ def run_torch_test(
|
|||
print(
|
||||
f"{input_format}\t{causal}\t{False}\t{batch_size}\t"
|
||||
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
|
||||
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}"
|
||||
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}"
|
||||
)
|
||||
row = {
|
||||
"use_gpu": use_gpu,
|
||||
|
|
@ -997,6 +1078,7 @@ def run_torch_test(
|
|||
"intra_op_num_threads": torch.get_num_threads(),
|
||||
"average_latency": torch_latency,
|
||||
"tflops": speed,
|
||||
"request_kernel": backend_name,
|
||||
"kernel": backend_name,
|
||||
}
|
||||
csv_writer.writerow(row)
|
||||
|
|
@ -1030,6 +1112,7 @@ def run_tflops_tests(args):
|
|||
"intra_op_num_threads",
|
||||
"average_latency",
|
||||
"tflops",
|
||||
"request_kernel",
|
||||
"kernel",
|
||||
]
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
|
|
@ -1224,7 +1307,7 @@ def _parse_arguments():
|
|||
"--repeats",
|
||||
required=False,
|
||||
type=int,
|
||||
default=100,
|
||||
default=0,
|
||||
help="number of repeats for performance test",
|
||||
)
|
||||
|
||||
|
|
@ -1269,6 +1352,9 @@ if __name__ == "__main__":
|
|||
args = _parse_arguments()
|
||||
print(f"arguments:{args}")
|
||||
|
||||
if args.repeats == 0:
|
||||
args.repeats = 10000 if args.use_gpu else 100
|
||||
|
||||
if args.use_gpu:
|
||||
assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
|
||||
assert torch.cuda.is_available()
|
||||
|
|
|
|||
Loading…
Reference in a new issue