mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-10 00:38:54 +00:00
[ROCm] Make KE reports with better format (#17049)
This commit is contained in:
parent
0471f6fbb3
commit
b4e0fc87ea
10 changed files with 38 additions and 35 deletions
|
|
@ -138,14 +138,14 @@ class BatchedGemmMetric(ke.ComputeMetric):
|
|||
batch: int
|
||||
|
||||
def report(self):
|
||||
prefix = (
|
||||
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} "
|
||||
common = (
|
||||
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}"
|
||||
)
|
||||
if self.duration <= 0:
|
||||
return prefix + "not supported"
|
||||
return "not supported " + common
|
||||
|
||||
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
|
||||
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
|
||||
|
||||
|
||||
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int):
|
||||
|
|
@ -173,8 +173,8 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int,
|
|||
duration_ms = -1
|
||||
if my_gemm.SelectOp(impl):
|
||||
duration_ms = my_gemm.Profile()
|
||||
FLOPs = batch * m * k * n * 2 # noqa: N806
|
||||
ke.report(BatchedGemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k, batch))
|
||||
flops = batch * m * k * n * 2
|
||||
ke.report(BatchedGemmMetric(impl, dtype, duration_ms, flops, transa, transb, m, n, k, batch))
|
||||
|
||||
|
||||
def profile_with_args(dtype, transa, transb, m, n, k, batch, sort):
|
||||
|
|
|
|||
|
|
@ -84,10 +84,10 @@ class ElementwiseMetric(ke.BandwidthMetric):
|
|||
hidden_size: int
|
||||
|
||||
def report(self):
|
||||
prefix = f"{self.name:<50} {self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} "
|
||||
common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}"
|
||||
if self.duration > 0:
|
||||
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
|
||||
return prefix + "not supported or redundant"
|
||||
return f"{self.duration:>6.2f} us {self.gbps:>5.2f} GB/s " + common
|
||||
return "not supported " + common
|
||||
|
||||
|
||||
def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func):
|
||||
|
|
|
|||
|
|
@ -113,13 +113,11 @@ class GemmFastGeluMetric(ke.ComputeMetric):
|
|||
k: int
|
||||
|
||||
def report(self):
|
||||
prefix = f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
if self.duration > 0:
|
||||
return (
|
||||
prefix
|
||||
+ f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
|
||||
)
|
||||
return prefix + "not supported"
|
||||
transab = transab_to_suffix((self.transa, self.transb))
|
||||
common = f"{self.dtype} m={self.m:<4} n={self.n:<4} k={self.k:<4} {transab}, {self.name}"
|
||||
if self.duration <= 0:
|
||||
return "not supported " + common
|
||||
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
|
||||
|
||||
|
||||
def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
|
||||
|
|
|
|||
|
|
@ -132,14 +132,14 @@ class GemmMetric(ke.ComputeMetric):
|
|||
k: int
|
||||
|
||||
def report(self):
|
||||
prefix = (
|
||||
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} "
|
||||
common = (
|
||||
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.name}"
|
||||
)
|
||||
if self.duration <= 0:
|
||||
return prefix + "not supported"
|
||||
return "not supported " + common
|
||||
|
||||
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
|
||||
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
|
||||
|
||||
|
||||
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int):
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
|
|||
f"num_channels={self.num_channels:<6} groups={self.groups:<4} {self.name}"
|
||||
)
|
||||
if self.duration > 0:
|
||||
return f"{self.duration:.2f} us, {self.gbps:.2f} GB/s " + common
|
||||
return f"{self.duration:>6.2f} us, {self.gbps:>5.2f} GB/s " + common
|
||||
return "not supported " + common
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -107,6 +107,11 @@ class BandwidthMetric(MetricBase):
|
|||
return self.bytes * 1e6 / self.duration / 1e9
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeAndBandwidthMetric(ComputeMetric, BandwidthMetric):
|
||||
pass
|
||||
|
||||
|
||||
class InstanceBenchmarkReporter:
|
||||
def __init__(self):
|
||||
self.sort = False
|
||||
|
|
|
|||
|
|
@ -109,10 +109,10 @@ class SkipLayerNormMetric(ke.BandwidthMetric):
|
|||
hidden_size: int
|
||||
|
||||
def report(self):
|
||||
prefix = f"{self.name:<50} {self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} "
|
||||
common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}"
|
||||
if self.duration > 0:
|
||||
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
|
||||
return prefix + "not supported or redundant"
|
||||
return f"{self.duration:6.2f} us, {self.gbps:5.2f} GB/s " + common
|
||||
return "not supported " + common
|
||||
|
||||
|
||||
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output):
|
||||
|
|
|
|||
|
|
@ -76,10 +76,10 @@ class SoftmaxMetric(ke.BandwidthMetric):
|
|||
is_log_softmax: bool
|
||||
|
||||
def report(self):
|
||||
prefix = f"{self.name:<110} {self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4}"
|
||||
common = f"{self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4} {self.name}"
|
||||
if self.duration > 0:
|
||||
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
|
||||
return prefix + "not supported"
|
||||
return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s " + common
|
||||
return "not supported " + common
|
||||
|
||||
|
||||
def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func):
|
||||
|
|
|
|||
|
|
@ -167,14 +167,14 @@ class StridedBatchedGemmMetric(ke.ComputeMetric):
|
|||
batch: int
|
||||
|
||||
def report(self):
|
||||
prefix = (
|
||||
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} "
|
||||
common = (
|
||||
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
|
||||
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}"
|
||||
)
|
||||
if self.duration <= 0:
|
||||
return prefix + "not supported"
|
||||
return "not supported " + common
|
||||
|
||||
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
|
||||
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
|
||||
|
||||
|
||||
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int):
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class VectorAddMetric(ke.BandwidthMetric):
|
|||
size: int
|
||||
|
||||
def report(self):
|
||||
return f"{self.name :<50} {self.dtype} size={self.size:<4}, {self.duration:.2f} us, {self.gbps:.2f} GB/s"
|
||||
return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} size={self.size:<4} {self.name}"
|
||||
|
||||
|
||||
def profile_vector_add_func(size, dtype, func):
|
||||
|
|
|
|||
Loading…
Reference in a new issue