diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py
index 6a9887f589..cd74e80ea1 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py
@@ -3,11 +3,21 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+import sys
+from itertools import product
+
 import kernel_explorer as ke
 import numpy as np
 import pytest
 
 
+def get_bert_sizes():
+    batch_sizes = [1]
+    seq_lens = [384]
+    hidden_sizes = [1024]
+    return product(batch_sizes, seq_lens, hidden_sizes)
+
+
 def dtype_to_funcs(dtype):
     type_map = {
         "float16": list(filter(lambda x: "FastGelu_half" in x, dir(ke))),
@@ -42,17 +52,17 @@ def run_fast_gelu(x_size, bias_size, dtype, func):
 
 
 test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024)]
+dtypes = ["float16", "float32", "float64"]
 
 
 @pytest.mark.parametrize("x_size, bias_size", test_cases)
-def test_fast_gelu(x_size, bias_size):
-    dtypes = ["float16", "float32", "float64"]
-    for dtype in dtypes:
-        for f in dtype_to_funcs(dtype):
-            run_fast_gelu(x_size, bias_size, dtype, f)
+@pytest.mark.parametrize("dtype", dtypes)
+def test_fast_gelu(x_size, bias_size, dtype):
+    for f in dtype_to_funcs(dtype):
+        run_fast_gelu(x_size, bias_size, dtype, f)
 
 
-def profile_vector_add_func(batch_size, seq_len, hidden_size, dtype, func):
+def profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func):
     x_size = [batch_size, seq_len, hidden_size * 3]
     bias_size = hidden_size * 3
     np.random.seed(0)
@@ -77,19 +87,29 @@ def profile_vector_add_func(batch_size, seq_len, hidden_size, dtype, func):
     )
 
 
+def profile_with_args(batch_size, seq_len, hidden_size, dtype):
+    for func in dtype_to_funcs(dtype):
+        profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func)
+    print()
+
+
 def profile():
-    batch_size = [1]
-    seq_len = [384]
-    hidden_size = [1024]
-    dtypes = ["float16", "float32", "float64"]
-    for dt in dtypes:
-        for bs in batch_size:
-            for sl in seq_len:
-                for hs in hidden_size:
-                    for f in dtype_to_funcs(dt):
-                        profile_vector_add_func(bs, sl, hs, dt, f)
-                    print()
+    for dtype in dtypes:
+        for bert_size in get_bert_sizes():
+            profile_with_args(*bert_size, dtype)
 
 
 if __name__ == "__main__":
-    profile()
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("profile with args")
+    group.add_argument("batch_size", type=int)
+    group.add_argument("seq_len", type=int)
+    group.add_argument("hidden_size", type=int)
+    group.add_argument("dtype", choices=dtypes)
+    if len(sys.argv) == 1:
+        profile()
+    else:
+        args = parser.parse_args()
+        profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype)
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py
index c12cb1fc63..c32099ff28 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 
 import re
+import sys
 from itertools import product
 
 import kernel_explorer as ke
@@ -116,13 +117,29 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
     )
 
 
+def profile_with_args(batch_size, seq_len, hidden_size, dtype):
+    for func in dtype_to_funcs(dtype):
+        profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func)
+    print()
+
+
 def profile():
     for dtype in dtypes:
         for bert_size in get_bert_sizes_profile():
-            for func in dtype_to_funcs(dtype):
-                profile_skip_layer_norm_func(*bert_size, dtype, func)
-            print()
+            profile_with_args(*bert_size, dtype)
 
 
 if __name__ == "__main__":
-    profile()
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("profile with args")
+    group.add_argument("batch_size", type=int)
+    group.add_argument("seq_len", type=int)
+    group.add_argument("hidden_size", type=int)
+    group.add_argument("dtype", choices=dtypes)
+    if len(sys.argv) == 1:
+        profile()
+    else:
+        args = parser.parse_args()
+        profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype)
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py
index dcd518be4e..31899056f7 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py
@@ -46,11 +46,14 @@ def run_vector_add(size, dtype, func):
     np.testing.assert_allclose(z_ref, z)
 
 
+dtypes = ["float16", "float32"]
+
+
 @pytest.mark.parametrize("size", [1, 3, 4, 16, 124, 125, 126, 127, 128, 129, 130, 131, 132, 1024])
-def test_vector_add(size):
-    dtypes = ["float16", "float32"]
-    for dtype in dtypes:
-        for f in dtype_to_funcs(dtype):
-            run_vector_add(size, dtype, f)
+@pytest.mark.parametrize("dtype", dtypes)
+def test_vector_add(size, dtype):
+    # dtype comes from the parametrize decorator; a loop over dtypes here would rerun every dtype per case.
+    for f in dtype_to_funcs(dtype):
+        run_vector_add(size, dtype, f)
 
 
@@ -69,16 +72,29 @@ def profile_vector_add_func(size, dtype, func):
     print(dtype, size, f, f"{t*1000:.2f} us", f"{size*3*(dtype_to_bytes(dtype))*1e3/t/1e9:.2f} GB/s")
 
 
+def profile_with_args(size, dtype):
+    for func in dtype_to_funcs(dtype):
+        profile_vector_add_func(size, dtype, func)
+    print()
+
+
 def profile():
     sizes = [10000, 100000, 1000000, 10000000]
-    dtypes = ["float16", "float32"]
     for dt in dtypes:
         for s in sizes:
-            for f in dtype_to_funcs(dt):
-                profile_vector_add_func(s, dt, f)
-            print()
+            profile_with_args(s, dt)
         print()
 
 
 if __name__ == "__main__":
-    profile()
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("profile with args")
+    group.add_argument("size", type=int)
+    group.add_argument("dtype", choices=dtypes)
+    if len(sys.argv) == 1:
+        profile()
+    else:
+        args = parser.parse_args()
+        profile_with_args(args.size, args.dtype)