diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h index 5691d15eaa..bf03ba3aa8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_tunable_op.h @@ -11,6 +11,9 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "contrib_ops/rocm/bert/fast_gelu_impl_kernel.h" +using onnxruntime::rocm::CeilDiv; +using onnxruntime::rocm::GPU_WARP_SIZE; + namespace onnxruntime { namespace contrib { namespace rocm { @@ -33,19 +36,28 @@ struct FastGeluParams : onnxruntime::rocm::tunable::OpParams { }; template -Status FastGeluOp(const FastGeluParams* params) { - // TODO(anyone): Add tail handling for FastGelu - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || - (params->bias_length == 0 && params->input_length % VecSize == 0))); +class FastGeluOp { + public: + Status operator()(const FastGeluParams* params) { + FastGeluKernelVec + <<input_length, ThreadsPerBlock * VecSize)), + dim3(ThreadsPerBlock), + 0, params->stream>>>( + params->input_length, params->bias_length, params->input, params->bias, params->output); + return HIP_CALL(hipGetLastError()); + } - FastGeluKernelVec - <<input_length, ThreadsPerBlock * VecSize)), - dim3(ThreadsPerBlock), - 0, params->stream>>>( - params->input_length, params->bias_length, params->input, params->bias, params->output); - return HIP_CALL(hipGetLastError()); -} + Status IsSupported(const FastGeluParams* params) { + // TODO(anyone): Add tail handling for FastGelu + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || + (params->bias_length == 0 && params->input_length % VecSize == 0))); + // Avoid redundant configurations + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); + + return Status::OK(); + } +}; template Status FastGeluStaticSelection(const FastGeluParams* params) { @@ -99,12 +111,12 @@ Status FastGeluStaticSelection(const FastGeluParams* params) { return HIP_CALL(hipGetLastError()); } -#define ADD_OP(threads_per_block) \ - this->ops_.emplace_back(FastGeluOp); \ - this->ops_.emplace_back(FastGeluOp); \ - this->ops_.emplace_back(FastGeluOp); \ - this->ops_.emplace_back(FastGeluOp); \ - this->ops_.emplace_back(FastGeluOp); +#define ADD_OP(threads_per_block) \ + this->ops_.emplace_back(FastGeluOp{}); \ + this->ops_.emplace_back(FastGeluOp{}); \ + this->ops_.emplace_back(FastGeluOp{}); \ + this->ops_.emplace_back(FastGeluOp{}); \ + this->ops_.emplace_back(FastGeluOp{}); template class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py index 525c550578..255e8ebdaa 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/fast_gelu_test.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import re import sys from itertools import product @@ -20,9 +21,9 @@ def get_bert_sizes(): def dtype_to_funcs(dtype): type_map = { - "float16": list(filter(lambda x: "FastGelu_half" in x, dir(ke))), - "float32": list(filter(lambda x: "FastGelu_float" in x, dir(ke))), - "float64": list(filter(lambda x: "FastGelu_double" in x, dir(ke))), + "float16": list(filter(lambda x: re.match("FastGelu.*_half.*", x), dir(ke))), + "float32": list(filter(lambda x: re.match("FastGelu.*_float.*", x), dir(ke))), + "float64": list(filter(lambda x: re.match("FastGelu.*_double.*", x), dir(ke))), } return type_map[dtype] @@ -43,15 +44,16 @@ def run_fast_gelu(x_size, bias_size, dtype, func): bias_d = ke.DeviceArray(bias) y_d = ke.DeviceArray(y) f = getattr(ke, func) - va = f(x_d, bias_d, y_d, x.size, bias.size) - va.Run() - y_d.UpdateHostNumpyArray() + my_op = f(x_d, bias_d, y_d, x.size, bias.size) + if my_op.IsSupported(): + my_op.Run() + y_d.UpdateHostNumpyArray() - y_ref = fast_gelu(x, bias) - np.testing.assert_allclose(y_ref, y, rtol=1e-02) + y_ref = fast_gelu(x, bias) + np.testing.assert_allclose(y_ref, y, rtol=1e-02) -test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024)] +test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024), ((1, 3, 3), 3)] dtypes = ["float16", "float32", "float64"] @@ -74,17 +76,19 @@ def profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func): bias_d = ke.DeviceArray(bias) y_d = ke.DeviceArray(y) f = getattr(ke, func) - va = f(x_d, bias_d, y_d, x.size, bias.size) - t = va.Profile() - print( - dtype, - batch_size, - seq_len, - hidden_size, - f, - f"{t*1000:.2f} us", - f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s", - ) + my_op = f(x_d, bias_d, y_d, x.size, bias.size) + if my_op.IsSupported(): + t = my_op.Profile() + print( + f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4}", + f"{t*1000:.2f} us", + f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s", + ) + else: + print( + f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4} not supported or redundant" + ) + sys.stdout.flush() def profile_with_args(batch_size, seq_len, hidden_size, dtype): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu index 19be4713e9..c85b631721 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu @@ -19,8 +19,34 @@ class FastGelu : public IKernelExplorer { : params_(this->Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), static_cast(output.ptr()), input_length, bias_length) {} + bool IsSupported() { + Status status = op_.IsSupported(¶ms_); + return status.IsOK(); + } + void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::FastGeluOp(¶ms_))); + ORT_THROW_IF_ERROR(op_(¶ms_)); + } + + private: + using ParamsT = contrib::rocm::FastGeluParams; + ParamsT params_{}; + contrib::rocm::FastGeluOp op_{}; +}; + +template +class FastGeluStaticSelection : public IKernelExplorer { + public: + FastGeluStaticSelection(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length) + : params_(this->Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), + static_cast(output.ptr()), input_length, bias_length) {} + + bool IsSupported() { + return true; + } + + void Run() override { + ORT_THROW_IF_ERROR((contrib::rocm::FastGeluStaticSelection(¶ms_))); } private: @@ -41,24 +67,29 @@ class FastGeluTunable : public IKernelExplorer { ORT_THROW_IF_ERROR(op_(¶ms_)); } + bool IsSupported() { + return true; + } + private: using ParamsT = contrib::rocm::FastGeluParams; ParamsT params_{}; contrib::rocm::FastGeluTunableOp op_{}; }; -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name"_"#type"_"#threads_per_block"_"#vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run); +#define REGISTER_OP(name, type, threads_per_block, vec_size) \ + py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run) \ + .def("IsSupported", &name::IsSupported); -#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ - REGISTER_OP(name, type, threads_per_block, 1) \ - REGISTER_OP(name, type, threads_per_block, 2) \ - REGISTER_OP(name, type, threads_per_block, 4) \ - REGISTER_OP(name, type, threads_per_block, 8) \ +#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ + REGISTER_OP(name, type, threads_per_block, 1) \ + REGISTER_OP(name, type, threads_per_block, 2) \ + REGISTER_OP(name, type, threads_per_block, 4) \ + REGISTER_OP(name, type, threads_per_block, 8) \ REGISTER_OP(name, type, threads_per_block, 16) #define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(name, type) \ @@ -71,21 +102,26 @@ class FastGeluTunable : public IKernelExplorer { REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512) -#define REGISTER_TUNABLE_OP(type) \ - py::class_>(m, "FastGelu_" #type "_Tunable") \ +#define REGISTER_OP_TYPED(name, type) \ + py::class_>(m, #name "_" #type) \ .def(py::init()) \ - .def("SetRepeats", &FastGeluTunable::SetRepeats) \ - .def("Profile", &FastGeluTunable::Profile) \ - .def("Run", &FastGeluTunable::Run); + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run) \ + .def("IsSupported", &name::IsSupported); void InitFastGelu(py::module m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double); - REGISTER_TUNABLE_OP(half); - REGISTER_TUNABLE_OP(float); - REGISTER_TUNABLE_OP(double); + REGISTER_OP_TYPED(FastGeluTunable, half); + REGISTER_OP_TYPED(FastGeluTunable, float); + REGISTER_OP_TYPED(FastGeluTunable, double); + + REGISTER_OP_TYPED(FastGeluStaticSelection, half); + REGISTER_OP_TYPED(FastGeluStaticSelection, float); + REGISTER_OP_TYPED(FastGeluStaticSelection, double); } } // namespace onnxruntime