[ROCm] Make KE reports with better format (#17049)

This commit is contained in:
cloudhan 2023-08-10 17:44:32 +08:00 committed by GitHub
parent 0471f6fbb3
commit b4e0fc87ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 38 additions and 35 deletions

View file

@ -138,14 +138,14 @@ class BatchedGemmMetric(ke.ComputeMetric):
batch: int
def report(self):
prefix = (
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} "
common = (
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}"
)
if self.duration <= 0:
return prefix + "not supported"
return "not supported " + common
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int):
@ -173,8 +173,8 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int,
duration_ms = -1
if my_gemm.SelectOp(impl):
duration_ms = my_gemm.Profile()
FLOPs = batch * m * k * n * 2 # noqa: N806
ke.report(BatchedGemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k, batch))
flops = batch * m * k * n * 2
ke.report(BatchedGemmMetric(impl, dtype, duration_ms, flops, transa, transb, m, n, k, batch))
def profile_with_args(dtype, transa, transb, m, n, k, batch, sort):

View file

@ -84,10 +84,10 @@ class ElementwiseMetric(ke.BandwidthMetric):
hidden_size: int
def report(self):
prefix = f"{self.name:<50} {self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} "
common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}"
if self.duration > 0:
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
return prefix + "not supported or redundant"
return f"{self.duration:>6.2f} us {self.gbps:>5.2f} GB/s " + common
return "not supported " + common
def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func):

View file

@ -113,13 +113,11 @@ class GemmFastGeluMetric(ke.ComputeMetric):
k: int
def report(self):
prefix = f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
if self.duration > 0:
return (
prefix
+ f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
)
return prefix + "not supported"
transab = transab_to_suffix((self.transa, self.transb))
common = f"{self.dtype} m={self.m:<4} n={self.n:<4} k={self.k:<4} {transab}, {self.name}"
if self.duration <= 0:
return "not supported " + common
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):

View file

@ -132,14 +132,14 @@ class GemmMetric(ke.ComputeMetric):
k: int
def report(self):
prefix = (
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} "
common = (
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.name}"
)
if self.duration <= 0:
return prefix + "not supported"
return "not supported " + common
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int):

View file

@ -131,7 +131,7 @@ class GroupNormNHWCMetric(ke.BandwidthMetric):
f"num_channels={self.num_channels:<6} groups={self.groups:<4} {self.name}"
)
if self.duration > 0:
return f"{self.duration:.2f} us, {self.gbps:.2f} GB/s " + common
return f"{self.duration:>6.2f} us, {self.gbps:>5.2f} GB/s " + common
return "not supported " + common

View file

@ -107,6 +107,11 @@ class BandwidthMetric(MetricBase):
return self.bytes * 1e6 / self.duration / 1e9
@dataclass
class ComputeAndBandwidthMetric(ComputeMetric, BandwidthMetric):
pass
class InstanceBenchmarkReporter:
def __init__(self):
self.sort = False

View file

@ -109,10 +109,10 @@ class SkipLayerNormMetric(ke.BandwidthMetric):
hidden_size: int
def report(self):
prefix = f"{self.name:<50} {self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} "
common = f"{self.dtype} batch_size={self.batch_size:<4} seq_len={self.seq_len:<4} hidden_size={self.hidden_size:<4} {self.name}"
if self.duration > 0:
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
return prefix + "not supported or redundant"
return f"{self.duration:6.2f} us, {self.gbps:5.2f} GB/s " + common
return "not supported " + common
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output):

View file

@ -76,10 +76,10 @@ class SoftmaxMetric(ke.BandwidthMetric):
is_log_softmax: bool
def report(self):
prefix = f"{self.name:<110} {self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4}"
common = f"{self.dtype} batch_count={self.batch_count:<4} softmax_elements={self.softmax_elements:<4} is_log_softmax={self.is_log_softmax:<4} {self.name}"
if self.duration > 0:
return prefix + f"{self.duration:.2f} us, {self.gbps:.2f} GB/s"
return prefix + "not supported"
return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s " + common
return "not supported " + common
def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func):

View file

@ -167,14 +167,14 @@ class StridedBatchedGemmMetric(ke.ComputeMetric):
batch: int
def report(self):
prefix = (
f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} "
common = (
f"{self.dtype} {transab_to_suffix((self.transa, self.transb))} "
f"m={self.m:<4} n={self.n:<4} k={self.k:<4} batch={self.batch:<3} {self.name}"
)
if self.duration <= 0:
return prefix + "not supported"
return "not supported " + common
return prefix + f"{self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int):

View file

@ -54,7 +54,7 @@ class VectorAddMetric(ke.BandwidthMetric):
size: int
def report(self):
return f"{self.name :<50} {self.dtype} size={self.size:<4}, {self.duration:.2f} us, {self.gbps:.2f} GB/s"
return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} size={self.size:<4} {self.name}"
def profile_vector_add_func(size, dtype, func):