mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
[ROCm] Update FastGelu and add kernel expolrer test for FastGeluStaticSelection (#13758)
### Description <!-- Describe your changes. --> 1. Update FastGelu conditions for supported parameters, avoid redundant configurations participating in tuning。 2. Add kernel explorer test for FastGeluStaticSelection ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
parent
7694b695a9
commit
c1cc1d5859
3 changed files with 110 additions and 58 deletions
|
|
@ -11,6 +11,9 @@
|
|||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "contrib_ops/rocm/bert/fast_gelu_impl_kernel.h"
|
||||
|
||||
using onnxruntime::rocm::CeilDiv;
|
||||
using onnxruntime::rocm::GPU_WARP_SIZE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
|
@ -33,19 +36,28 @@ struct FastGeluParams : onnxruntime::rocm::tunable::OpParams {
|
|||
};
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
Status FastGeluOp(const FastGeluParams<T>* params) {
|
||||
// TODO(anyone): Add tail handling for FastGelu
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) ||
|
||||
(params->bias_length == 0 && params->input_length % VecSize == 0)));
|
||||
class FastGeluOp {
|
||||
public:
|
||||
Status operator()(const FastGeluParams<T>* params) {
|
||||
FastGeluKernelVec<T, ThreadsPerBlock, VecSize>
|
||||
<<<dim3(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->stream>>>(
|
||||
params->input_length, params->bias_length, params->input, params->bias, params->output);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
FastGeluKernelVec<T, ThreadsPerBlock, VecSize>
|
||||
<<<dim3(onnxruntime::rocm::CeilDiv(params->input_length, ThreadsPerBlock * VecSize)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->stream>>>(
|
||||
params->input_length, params->bias_length, params->input, params->bias, params->output);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
Status IsSupported(const FastGeluParams<T>* params) {
|
||||
// TODO(anyone): Add tail handling for FastGelu
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) ||
|
||||
(params->bias_length == 0 && params->input_length % VecSize == 0)));
|
||||
// Avoid redundant configurations
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status FastGeluStaticSelection(const FastGeluParams<T>* params) {
|
||||
|
|
@ -99,12 +111,12 @@ Status FastGeluStaticSelection(const FastGeluParams<half>* params) {
|
|||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
#define ADD_OP(threads_per_block) \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 1>); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 2>); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 4>); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 8>); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 16>);
|
||||
#define ADD_OP(threads_per_block) \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 1>{}); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 2>{}); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 4>{}); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 8>{}); \
|
||||
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 16>{});
|
||||
|
||||
template <typename T>
|
||||
class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<FastGeluParams<T>> {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
import sys
|
||||
from itertools import product
|
||||
|
||||
|
|
@ -20,9 +21,9 @@ def get_bert_sizes():
|
|||
|
||||
def dtype_to_funcs(dtype):
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: "FastGelu_half" in x, dir(ke))),
|
||||
"float32": list(filter(lambda x: "FastGelu_float" in x, dir(ke))),
|
||||
"float64": list(filter(lambda x: "FastGelu_double" in x, dir(ke))),
|
||||
"float16": list(filter(lambda x: re.match("FastGelu.*_half.*", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.match("FastGelu.*_float.*", x), dir(ke))),
|
||||
"float64": list(filter(lambda x: re.match("FastGelu.*_double.*", x), dir(ke))),
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
|
|
@ -43,15 +44,16 @@ def run_fast_gelu(x_size, bias_size, dtype, func):
|
|||
bias_d = ke.DeviceArray(bias)
|
||||
y_d = ke.DeviceArray(y)
|
||||
f = getattr(ke, func)
|
||||
va = f(x_d, bias_d, y_d, x.size, bias.size)
|
||||
va.Run()
|
||||
y_d.UpdateHostNumpyArray()
|
||||
my_op = f(x_d, bias_d, y_d, x.size, bias.size)
|
||||
if my_op.IsSupported():
|
||||
my_op.Run()
|
||||
y_d.UpdateHostNumpyArray()
|
||||
|
||||
y_ref = fast_gelu(x, bias)
|
||||
np.testing.assert_allclose(y_ref, y, rtol=1e-02)
|
||||
y_ref = fast_gelu(x, bias)
|
||||
np.testing.assert_allclose(y_ref, y, rtol=1e-02)
|
||||
|
||||
|
||||
test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024)]
|
||||
test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024), ((1, 3, 3), 3)]
|
||||
dtypes = ["float16", "float32", "float64"]
|
||||
|
||||
|
||||
|
|
@ -74,17 +76,19 @@ def profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
bias_d = ke.DeviceArray(bias)
|
||||
y_d = ke.DeviceArray(y)
|
||||
f = getattr(ke, func)
|
||||
va = f(x_d, bias_d, y_d, x.size, bias.size)
|
||||
t = va.Profile()
|
||||
print(
|
||||
dtype,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
f,
|
||||
f"{t*1000:.2f} us",
|
||||
f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s",
|
||||
)
|
||||
my_op = f(x_d, bias_d, y_d, x.size, bias.size)
|
||||
if my_op.IsSupported():
|
||||
t = my_op.Profile()
|
||||
print(
|
||||
f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4}",
|
||||
f"{t*1000:.2f} us",
|
||||
f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s",
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4} not supported or redundant"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype):
|
||||
|
|
|
|||
|
|
@ -19,8 +19,34 @@ class FastGelu : public IKernelExplorer {
|
|||
: params_(this->Stream(), static_cast<T*>(input.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<T*>(output.ptr()), input_length, bias_length) {}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = op_.IsSupported(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::FastGeluOp<T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
ORT_THROW_IF_ERROR(op_(¶ms_));
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::FastGeluParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::FastGeluOp<T, ThreadsPerBlock, VecSize> op_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FastGeluStaticSelection : public IKernelExplorer {
|
||||
public:
|
||||
FastGeluStaticSelection(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length)
|
||||
: params_(this->Stream(), static_cast<T*>(input.ptr()), static_cast<T*>(bias.ptr()),
|
||||
static_cast<T*>(output.ptr()), input_length, bias_length) {}
|
||||
|
||||
bool IsSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::FastGeluStaticSelection<T>(¶ms_)));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -41,24 +67,29 @@ class FastGeluTunable : public IKernelExplorer {
|
|||
ORT_THROW_IF_ERROR(op_(¶ms_));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::FastGeluParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::FastGeluTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name"_"#type"_"#threads_per_block"_"#vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run);
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
REGISTER_OP(name, type, threads_per_block, 1) \
|
||||
REGISTER_OP(name, type, threads_per_block, 2) \
|
||||
REGISTER_OP(name, type, threads_per_block, 4) \
|
||||
REGISTER_OP(name, type, threads_per_block, 8) \
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
REGISTER_OP(name, type, threads_per_block, 1) \
|
||||
REGISTER_OP(name, type, threads_per_block, 2) \
|
||||
REGISTER_OP(name, type, threads_per_block, 4) \
|
||||
REGISTER_OP(name, type, threads_per_block, 8) \
|
||||
REGISTER_OP(name, type, threads_per_block, 16)
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(name, type) \
|
||||
|
|
@ -71,21 +102,26 @@ class FastGeluTunable : public IKernelExplorer {
|
|||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512)
|
||||
|
||||
#define REGISTER_TUNABLE_OP(type) \
|
||||
py::class_<FastGeluTunable<type>>(m, "FastGelu_" #type "_Tunable") \
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
|
||||
.def("SetRepeats", &FastGeluTunable<type>::SetRepeats) \
|
||||
.def("Profile", &FastGeluTunable<type>::Profile) \
|
||||
.def("Run", &FastGeluTunable<type>::Run);
|
||||
.def("SetRepeats", &name<type>::SetRepeats) \
|
||||
.def("Profile", &name<type>::Profile) \
|
||||
.def("Run", &name<type>::Run) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
|
||||
void InitFastGelu(py::module m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double);
|
||||
|
||||
REGISTER_TUNABLE_OP(half);
|
||||
REGISTER_TUNABLE_OP(float);
|
||||
REGISTER_TUNABLE_OP(double);
|
||||
REGISTER_OP_TYPED(FastGeluTunable, half);
|
||||
REGISTER_OP_TYPED(FastGeluTunable, float);
|
||||
REGISTER_OP_TYPED(FastGeluTunable, double);
|
||||
|
||||
REGISTER_OP_TYPED(FastGeluStaticSelection, half);
|
||||
REGISTER_OP_TYPED(FastGeluStaticSelection, float);
|
||||
REGISTER_OP_TYPED(FastGeluStaticSelection, double);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue