mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Allow fastgelu/skiplayernorm profile by pass args from commandline (#13025)
**Description**: Describe your changes. This allow us quickly launch a microbench session by, for example: `python skip_layer_norm_test.py 8 128 128 float32 `
This commit is contained in:
parent
32c2c4b480
commit
c26bb1bb19
3 changed files with 82 additions and 29 deletions
|
|
@ -3,11 +3,21 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import sys
|
||||
from itertools import product
|
||||
|
||||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def get_bert_sizes():
|
||||
batch_sizes = [1]
|
||||
seq_lens = [384]
|
||||
hidden_sizes = [1024]
|
||||
return product(batch_sizes, seq_lens, hidden_sizes)
|
||||
|
||||
|
||||
def dtype_to_funcs(dtype):
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: "FastGelu_half" in x, dir(ke))),
|
||||
|
|
@ -42,17 +52,17 @@ def run_fast_gelu(x_size, bias_size, dtype, func):
|
|||
|
||||
|
||||
test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024)]
|
||||
dtypes = ["float16", "float32", "float64"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x_size, bias_size", test_cases)
|
||||
def test_fast_gelu(x_size, bias_size):
|
||||
dtypes = ["float16", "float32", "float64"]
|
||||
for dtype in dtypes:
|
||||
for f in dtype_to_funcs(dtype):
|
||||
run_fast_gelu(x_size, bias_size, dtype, f)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_fast_gelu(x_size, bias_size, dtype):
|
||||
for f in dtype_to_funcs(dtype):
|
||||
run_fast_gelu(x_size, bias_size, dtype, f)
|
||||
|
||||
|
||||
def profile_vector_add_func(batch_size, seq_len, hidden_size, dtype, func):
|
||||
def profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func):
|
||||
x_size = [batch_size, seq_len, hidden_size * 3]
|
||||
bias_size = hidden_size * 3
|
||||
np.random.seed(0)
|
||||
|
|
@ -77,19 +87,29 @@ def profile_vector_add_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
)
|
||||
|
||||
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func)
|
||||
print()
|
||||
|
||||
|
||||
def profile():
|
||||
batch_size = [1]
|
||||
seq_len = [384]
|
||||
hidden_size = [1024]
|
||||
dtypes = ["float16", "float32", "float64"]
|
||||
for dt in dtypes:
|
||||
for bs in batch_size:
|
||||
for sl in seq_len:
|
||||
for hs in hidden_size:
|
||||
for f in dtype_to_funcs(dt):
|
||||
profile_vector_add_func(bs, sl, hs, dt, f)
|
||||
print()
|
||||
for dtype in dtypes:
|
||||
for bert_size in get_bert_sizes():
|
||||
profile_with_args(*bert_size, dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
profile()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group("profile with args")
|
||||
group.add_argument("batch_size", type=int)
|
||||
group.add_argument("seq_len", type=int)
|
||||
group.add_argument("hidden_size", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
if len(sys.argv) == 1:
|
||||
profile()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
import sys
|
||||
from itertools import product
|
||||
|
||||
import kernel_explorer as ke
|
||||
|
|
@ -116,13 +117,29 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
)
|
||||
|
||||
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func)
|
||||
print()
|
||||
|
||||
|
||||
def profile():
|
||||
for dtype in dtypes:
|
||||
for bert_size in get_bert_sizes_profile():
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_skip_layer_norm_func(*bert_size, dtype, func)
|
||||
print()
|
||||
profile_with_args(*bert_size, dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
profile()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group("profile with args")
|
||||
group.add_argument("batch_size", type=int)
|
||||
group.add_argument("seq_len", type=int)
|
||||
group.add_argument("hidden_size", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
if len(sys.argv) == 1:
|
||||
profile()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype)
|
||||
|
|
|
|||
|
|
@ -46,9 +46,12 @@ def run_vector_add(size, dtype, func):
|
|||
np.testing.assert_allclose(z_ref, z)
|
||||
|
||||
|
||||
dtypes = ["float16", "float32"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size", [1, 3, 4, 16, 124, 125, 126, 127, 128, 129, 130, 131, 132, 1024])
|
||||
def test_vector_add(size):
|
||||
dtypes = ["float16", "float32"]
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_vector_add(size, dtype):
|
||||
for dtype in dtypes:
|
||||
for f in dtype_to_funcs(dtype):
|
||||
run_vector_add(size, dtype, f)
|
||||
|
|
@ -69,16 +72,29 @@ def profile_vector_add_func(size, dtype, func):
|
|||
print(dtype, size, f, f"{t*1000:.2f} us", f"{size*3*(dtype_to_bytes(dtype))*1e3/t/1e9:.2f} GB/s")
|
||||
|
||||
|
||||
def profile_with_args(size, dtype):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_vector_add_func(size, dtype, func)
|
||||
print()
|
||||
|
||||
|
||||
def profile():
|
||||
sizes = [10000, 100000, 1000000, 10000000]
|
||||
dtypes = ["float16", "float32"]
|
||||
for dt in dtypes:
|
||||
for s in sizes:
|
||||
for f in dtype_to_funcs(dt):
|
||||
profile_vector_add_func(s, dt, f)
|
||||
print()
|
||||
profile_with_args(s, dt)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
profile()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group("profile with args")
|
||||
group.add_argument("size", type=int)
|
||||
group.add_argument("dtype", choices=dtypes)
|
||||
if len(sys.argv) == 1:
|
||||
profile()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(args.size, args.dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue