From bf2cc808a180da4a02147271a4691fcbd83055bb Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Thu, 9 Mar 2023 22:27:01 +0800 Subject: [PATCH] [ROCm] SkipLayerNorm: add more configs for block size; loosen constraints (#14900) ### Description * add more configs for `threads_per_block` in SkipLayerNorm, also in kernel explorer. * loosen constraints for hidden_size, so that `SkipLayerNormSmallOp` can be selected for larger hidden sizes. * add flag for optional output in kernel_explorer ### Motivation and Context --- .../rocm/bert/skip_layer_norm_tunable_op.h | 17 +++++++++-- .../kernels/rocm/skip_layer_norm.cu | 11 ++++++- .../kernels/skip_layer_norm_test.py | 30 +++++++++++++------ 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h index b8d0dfee74..b37ed32563 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -23,7 +23,7 @@ struct SkipLayerNormParams : OpParams { SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), + : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} std::string Signature() const override { @@ -45,8 +45,10 @@ struct SkipLayerNormParams : OpParams { template Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { + // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, + // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld <= 1024 && params->ld % VecSize == 0 && + !((params->ld <= 8192 && params->ld % VecSize == 0 && params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); SkipLayerNormKernelSmall<<element_count, params->ld)), dim3(ThreadsPerBlock), @@ -159,7 +161,16 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) + ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) template class SkipLayerNormTunableOp : public TunableOp> { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu index 21ffe83293..52f1f5c1b1 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu @@ -136,7 +136,16 @@ class SkipLayerNormTunable : public IKernelExplorer { REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \ - REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 576) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 640) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 704) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 768) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 832) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024) #define REGISTER_OP_TYPED(name, type) \ py::class_>(m, #name "_" #type) \ diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py index 62cc1e4f89..2a653f92a4 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -46,7 +46,7 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon): return output, val -def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func): +def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False): np.random.seed(0) input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) @@ -56,7 +56,11 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: # Because of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small epsilon = 0.05 if hidden_size < 8 else 0.0005 output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + output_optional = ( + np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + if has_optional_output + else np.empty((0), dtype=dtype) + ) input_d = ke.DeviceArray(input_x) skip_d = ke.DeviceArray(skip) @@ -86,8 +90,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: optional_d.UpdateHostNumpyArray() y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon) - np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05) - np.testing.assert_almost_equal(y_optional, output_optional, decimal=1e-05) + np.testing.assert_almost_equal(y_ref, output_y, decimal=1) + if has_optional_output: + np.testing.assert_almost_equal(y_optional, output_optional, decimal=1) dtypes = ["float32", "float16"] @@ -113,7 +118,7 @@ class SkipLayerNormMetric(ke.BandwidthMetric): return prefix + "not supported or redundant" -def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): +def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output): np.random.seed(0) input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) @@ -122,7 +127,11 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): bias = np.random.rand(hidden_size).astype(dtype) epsilon = 0.0005 output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) - output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + output_optional = ( + np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + if has_optional_output + else np.empty((0), dtype=dtype) + ) input_d = ke.DeviceArray(input_x) skip_d = ke.DeviceArray(skip) @@ -154,10 +163,10 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) -def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True): +def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): - profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func) + profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output) def profile(): @@ -177,9 +186,12 @@ if __name__ == "__main__": group.add_argument("hidden_size", type=int) group.add_argument("dtype", choices=dtypes) group.add_argument("--sort", action="store_true") + group.add_argument("--has_optional_output", "-o", action="store_true") if len(sys.argv) == 1: profile() else: args = parser.parse_args() - profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort) + profile_with_args( + args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output + )