From 2ef1f8b93ea7a3717d7e95d51963947364f003f7 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Fri, 23 Sep 2022 16:39:44 +0800 Subject: [PATCH] [ROCm] add tunable SkipLayerNorm for ROCm EP (#12817) **Description**: Describe your changes. Related PR: https://github.com/microsoft/onnxruntime/pull/12803 https://github.com/microsoft/onnxruntime/pull/12816 https://github.com/microsoft/onnxruntime/pull/12821 1. Add a tunable SkipLayerNorm for the ROCm EP. 2. Keep the original implementation when tuning is disabled. **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --- .../contrib_ops/rocm/bert/skip_layer_norm.cc | 6 ++- .../contrib_ops/rocm/bert/skip_layer_norm.h | 1 + .../rocm/bert/skip_layer_norm_impl.cu | 22 +++++++--- .../rocm/bert/skip_layer_norm_impl.h | 20 ++++----- ...norm_op.h => skip_layer_norm_tunable_op.h} | 34 ++++++++++++--- .../kernels/skip_layer_norm.cc | 41 ++++++++++++++++++- .../kernels/skip_layer_norm_test.py | 8 ++-- 7 files changed, 105 insertions(+), 27 deletions(-) rename onnxruntime/contrib_ops/rocm/bert/{skip_layer_norm_op.h => skip_layer_norm_tunable_op.h} (66%) diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc index b8c6413c37..c1cc1a20e2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -5,6 +5,7 @@ #include "core/providers/rocm/rocm_common.h" #include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" +#include "contrib_ops/rocm/bert/transformer_common.h" namespace onnxruntime { namespace contrib { @@ -30,6 +31,8 @@ template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); + const TransformerOptions* options = 
TransformerOptions::GetInstance(); + tuning_ = options->IsTuningEnabled(); } template @@ -106,7 +109,8 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, epsilon_, hidden_size, - static_cast(element_count)); + static_cast(element_count), + tuning_); } } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h index 07d7037227..5e43f193f6 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h @@ -19,6 +19,7 @@ class SkipLayerNorm final : public RocmKernel { private: float epsilon_; + bool tuning_; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu index 120ccd9162..a4c94236d1 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -33,6 +33,7 @@ limitations under the License. 
#include #include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" namespace onnxruntime { namespace contrib { @@ -41,9 +42,18 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, const int ld, const int element_count) { - // this must be true because n is the total size of the tensor + const T* beta, const T* bias, float epsilon, int ld, int element_count, bool tuning) { + // this must be true because element_count is the total size of the tensor assert(element_count % ld == 0); + + if (tuning) { + static SkipLayerNormTunableOp op; + op.EnableTuning(); + + SkipLayerNormParams op_params(stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count); + return op(&op_params); + } + bool hasBias = (bias == nullptr) ? false : true; if (0 == (ld % 4)) { const int grid_size = element_count / ld; @@ -105,13 +115,13 @@ Status LaunchSkipLayerNormKernel( template Status LaunchSkipLayerNormKernel(hipStream_t stream, float* output, const float* input, const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, const int ld, - const int element_count); + const float* bias, float epsilon, int ld, + int element_count, bool tuning); template Status LaunchSkipLayerNormKernel(hipStream_t stream, half* output, const half* input, const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, const int ld, - const int element_count); + const half* bias, float epsilon, int ld, + int element_count, bool tuning); } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h index 727116a50f..9758988bad 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ 
b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -11,16 +11,16 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( hipStream_t stream, - T* output, // output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count // number of elements in input tensor -); + T* output, // output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + const T* bias, // Optional bias tensor (nullptr when absent) + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count, // number of elements in input tensor + bool tuning); } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h similarity index 66% rename from onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h rename to onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h index 77fede8431..e504c6b78b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -20,8 +20,7 @@ template struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams { SkipLayerNormParams(hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, const T* beta, - const T* bias, float epsilon, const int ld, - const int element_count) + const T* bias, float epsilon, int ld, int element_count) : OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), 
ld(ld), element_count(element_count) {} @@ -36,9 +35,9 @@ struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams { const T* gamma; const T* beta; const T* bias; - const float epsilon; - const int ld; - const int element_count; + float epsilon; + int ld; + int element_count; }; template @@ -55,6 +54,31 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { return HIP_CALL(hipGetLastError()); } +#define ADD_OP(threads_per_block) \ + this->ops_.emplace_back(SkipLayerNormSmallOp); \ + this->ops_.emplace_back(SkipLayerNormSmallOp); \ + this->ops_.emplace_back(SkipLayerNormSmallOp); \ + this->ops_.emplace_back(SkipLayerNormSmallOp); \ + this->ops_.emplace_back(SkipLayerNormSmallOp); + +template +class SkipLayerNormTunableOp : public onnxruntime::rocm::tunable::TunableOp> { + public: + SkipLayerNormTunableOp() { + ADD_OP(64) + ADD_OP(128) + ADD_OP(192) + ADD_OP(256) + ADD_OP(320) + ADD_OP(384) + + // NOTE: the kernel with id 3 seems to be better in the general case, so set it as the default one + this->SetDefaultId(3); + } +}; + +#undef ADD_OP + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc index 8fcd92159c..e85ece82ac 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc @@ -6,7 +6,7 @@ #include #include -#include "contrib_ops/rocm/bert/skip_layer_norm_op.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" #include "python/tools/kernel_explorer/device_array.h" #include "python/tools/kernel_explorer/kernel_explorer_interface.h" @@ -38,6 +38,32 @@ class SkipLayerNormSmall : public IKernelExplorer { ParamsT params_{}; }; +template +class SkipLayerNormTunable : public IKernelExplorer { + public: + SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip, + 
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, + float epsilon, int hidden_size, int element_count) + : params_(this->Stream(), static_cast(output.ptr()), static_cast(input.ptr()), + static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), + static_cast(bias.ptr()), epsilon, hidden_size, element_count) { + op_.EnableTuning(); + } + + void Run() override { + ORT_THROW_IF_ERROR(op_(&params_)); + } + + bool IsSupported() { + return true; + } + + private: + using ParamsT = contrib::rocm::SkipLayerNormParams; + ParamsT params_{}; + contrib::rocm::SkipLayerNormTunableOp op_{}; +}; + #define REGISTER_OP(name, type, threads_per_block, vec_size) \ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init>(m, "SkipLayerNorm_" #type "_Tunable") \ + .def(py::init()) \ + .def("SetRepeats", &SkipLayerNormTunable::SetRepeats) \ + .def("Profile", &SkipLayerNormTunable::Profile) \ + .def("Run", &SkipLayerNormTunable::Run) \ + .def("IsSupported", &SkipLayerNormTunable::IsSupported); + void InitSkipLayerNorm(py::module m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float); + + REGISTER_TUNABLE_OP(half); + REGISTER_TUNABLE_OP(float); } } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py index 5cbadcb73c..e062afcc59 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import re from itertools import product import kernel_explorer as ke @@ -19,8 +20,8 @@ def get_bert_sizes(): def dtype_to_funcs(dtype): type_map = { - "float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))), - "float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))), + "float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))), + "float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))), } return type_map[dtype] @@ -109,9 +110,8 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): def profile(): - bert_sizes = get_bert_sizes() for dtype in dtypes: - for bert_size in bert_sizes: + for bert_size in get_bert_sizes(): for func in dtype_to_funcs(dtype): profile_skip_layer_norm_func(*bert_size, dtype, func) print()