[ROCm] Update FastGelu and add kernel explorer test for FastGeluStaticSelection (#13758)

### Description
<!-- Describe your changes. -->

1. Update the FastGelu conditions for supported parameters so that redundant
configurations do not participate in tuning.
2. Add kernel explorer test for FastGeluStaticSelection

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
PeixuanZuo 2022-12-08 12:37:10 +08:00 committed by GitHub
parent 7694b695a9
commit c1cc1d5859
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 58 deletions

View file

@ -11,6 +11,9 @@
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/fast_gelu_impl_kernel.h"
using onnxruntime::rocm::CeilDiv;
using onnxruntime::rocm::GPU_WARP_SIZE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
@ -33,19 +36,28 @@ struct FastGeluParams : onnxruntime::rocm::tunable::OpParams {
};
template <typename T, int ThreadsPerBlock, int VecSize>
Status FastGeluOp(const FastGeluParams<T>* params) {
// TODO(anyone): Add tail handling for FastGelu
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) ||
(params->bias_length == 0 && params->input_length % VecSize == 0)));
class FastGeluOp {
public:
// Launches FastGeluKernelVec<T, ThreadsPerBlock, VecSize> on params->stream and
// returns HIP_CALL(hipGetLastError()) for that launch.
Status operator()(const FastGeluParams<T>* params) {
  // One block handles ThreadsPerBlock * VecSize elements; round the block count up.
  const dim3 grid_dim(CeilDiv(params->input_length, ThreadsPerBlock * VecSize));
  const dim3 block_dim(ThreadsPerBlock);
  FastGeluKernelVec<T, ThreadsPerBlock, VecSize><<<grid_dim, block_dim, 0, params->stream>>>(
      params->input_length, params->bias_length, params->input, params->bias, params->output);
  return HIP_CALL(hipGetLastError());
}
FastGeluKernelVec<T, ThreadsPerBlock, VecSize>
<<<dim3(onnxruntime::rocm::CeilDiv(params->input_length, ThreadsPerBlock * VecSize)),
dim3(ThreadsPerBlock),
0, params->stream>>>(
params->input_length, params->bias_length, params->input, params->bias, params->output);
return HIP_CALL(hipGetLastError());
}
// Checks whether this <ThreadsPerBlock, VecSize> instantiation accepts `params`.
// Returns Status::OK() when neither check below fires.
// NOTE(review): TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF is defined elsewhere; it
// presumably returns early with an "unsupported argument" status when its condition
// is true -- confirm against the macro definition.
Status IsSupported(const FastGeluParams<T>* params) {
// TODO(anyone): Add tail handling for FastGelu
// The accepted shapes are: input_length is a multiple of VecSize, and bias_length is
// either 0 or a positive multiple of VecSize. Anything else is flagged.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) ||
(params->bias_length == 0 && params->input_length % VecSize == 0)));
// Avoid redundant configurations
// Flags the configuration unless input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize.
// NOTE(review): presumably a smaller ThreadsPerBlock already covers such inputs -- confirm.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize));
return Status::OK();
}
};
template <typename T>
Status FastGeluStaticSelection(const FastGeluParams<T>* params) {
@ -99,12 +111,12 @@ Status FastGeluStaticSelection(const FastGeluParams<half>* params) {
return HIP_CALL(hipGetLastError());
}
#define ADD_OP(threads_per_block) \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 1>); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 2>); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 4>); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 8>); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 16>);
#define ADD_OP(threads_per_block) \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 1>{}); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 2>{}); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 4>{}); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 8>{}); \
this->ops_.emplace_back(FastGeluOp<T, threads_per_block, 16>{});
template <typename T>
class FastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<FastGeluParams<T>> {

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import re
import sys
from itertools import product
@ -20,9 +21,9 @@ def get_bert_sizes():
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: "FastGelu_half" in x, dir(ke))),
"float32": list(filter(lambda x: "FastGelu_float" in x, dir(ke))),
"float64": list(filter(lambda x: "FastGelu_double" in x, dir(ke))),
"float16": list(filter(lambda x: re.match("FastGelu.*_half.*", x), dir(ke))),
"float32": list(filter(lambda x: re.match("FastGelu.*_float.*", x), dir(ke))),
"float64": list(filter(lambda x: re.match("FastGelu.*_double.*", x), dir(ke))),
}
return type_map[dtype]
@ -43,15 +44,16 @@ def run_fast_gelu(x_size, bias_size, dtype, func):
bias_d = ke.DeviceArray(bias)
y_d = ke.DeviceArray(y)
f = getattr(ke, func)
va = f(x_d, bias_d, y_d, x.size, bias.size)
va.Run()
y_d.UpdateHostNumpyArray()
my_op = f(x_d, bias_d, y_d, x.size, bias.size)
if my_op.IsSupported():
my_op.Run()
y_d.UpdateHostNumpyArray()
y_ref = fast_gelu(x, bias)
np.testing.assert_allclose(y_ref, y, rtol=1e-02)
y_ref = fast_gelu(x, bias)
np.testing.assert_allclose(y_ref, y, rtol=1e-02)
test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024)]
test_cases = [((2, 16), 16), ((1, 2, 768), 768), ((1, 2, 1024), 1024), ((1, 3, 3), 3)]
dtypes = ["float16", "float32", "float64"]
@ -74,17 +76,19 @@ def profile_fast_gelu_func(batch_size, seq_len, hidden_size, dtype, func):
bias_d = ke.DeviceArray(bias)
y_d = ke.DeviceArray(y)
f = getattr(ke, func)
va = f(x_d, bias_d, y_d, x.size, bias.size)
t = va.Profile()
print(
dtype,
batch_size,
seq_len,
hidden_size,
f,
f"{t*1000:.2f} us",
f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s",
)
my_op = f(x_d, bias_d, y_d, x.size, bias.size)
if my_op.IsSupported():
t = my_op.Profile()
print(
f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4}",
f"{t*1000:.2f} us",
f"{(x.size*2+bias.size)*x.itemsize*1e3/t/1e9:.2f} GB/s",
)
else:
print(
f"{func:<50} {dtype} batch_size={batch_size:<4} seq_len={seq_len:<4} hidden_size={hidden_size:<4} not supported or redundant"
)
sys.stdout.flush()
def profile_with_args(batch_size, seq_len, hidden_size, dtype):

View file

@ -19,8 +19,34 @@ class FastGelu : public IKernelExplorer {
: params_(this->Stream(), static_cast<T*>(input.ptr()), static_cast<T*>(bias.ptr()),
static_cast<T*>(output.ptr()), input_length, bias_length) {}
bool IsSupported() {
Status status = op_.IsSupported(&params_);
return status.IsOK();
}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::FastGeluOp<T, ThreadsPerBlock, VecSize>(&params_)));
ORT_THROW_IF_ERROR(op_(&params_));
}
private:
using ParamsT = contrib::rocm::FastGeluParams<T>;
ParamsT params_{};
contrib::rocm::FastGeluOp<T, ThreadsPerBlock, VecSize> op_{};
};
template <typename T>
class FastGeluStaticSelection : public IKernelExplorer {
public:
FastGeluStaticSelection(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length)
: params_(this->Stream(), static_cast<T*>(input.ptr()), static_cast<T*>(bias.ptr()),
static_cast<T*>(output.ptr()), input_length, bias_length) {}
// The static-selection wrapper performs no parameter check here and is always
// reported as supported.
bool IsSupported() { return true; }
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::FastGeluStaticSelection<T>(&params_)));
}
private:
@ -41,24 +67,29 @@ class FastGeluTunable : public IKernelExplorer {
ORT_THROW_IF_ERROR(op_(&params_));
}
// The tunable wrapper performs no parameter check here and is always reported
// as supported.
bool IsSupported() { return true; }
private:
using ParamsT = contrib::rocm::FastGeluParams<T>;
ParamsT params_{};
contrib::rocm::FastGeluTunableOp<T> op_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name"_"#type"_"#threads_per_block"_"#vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run);
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
REGISTER_OP(name, type, threads_per_block, 1) \
REGISTER_OP(name, type, threads_per_block, 2) \
REGISTER_OP(name, type, threads_per_block, 4) \
REGISTER_OP(name, type, threads_per_block, 8) \
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
REGISTER_OP(name, type, threads_per_block, 1) \
REGISTER_OP(name, type, threads_per_block, 2) \
REGISTER_OP(name, type, threads_per_block, 4) \
REGISTER_OP(name, type, threads_per_block, 8) \
REGISTER_OP(name, type, threads_per_block, 16)
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(name, type) \
@ -71,21 +102,26 @@ class FastGeluTunable : public IKernelExplorer {
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 448) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 512)
#define REGISTER_TUNABLE_OP(type) \
py::class_<FastGeluTunable<type>>(m, "FastGelu_" #type "_Tunable") \
#define REGISTER_OP_TYPED(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
.def("SetRepeats", &FastGeluTunable<type>::SetRepeats) \
.def("Profile", &FastGeluTunable<type>::Profile) \
.def("Run", &FastGeluTunable<type>::Run);
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run) \
.def("IsSupported", &name<type>::IsSupported);
void InitFastGelu(py::module m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, float);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK(FastGelu, double);
REGISTER_TUNABLE_OP(half);
REGISTER_TUNABLE_OP(float);
REGISTER_TUNABLE_OP(double);
REGISTER_OP_TYPED(FastGeluTunable, half);
REGISTER_OP_TYPED(FastGeluTunable, float);
REGISTER_OP_TYPED(FastGeluTunable, double);
REGISTER_OP_TYPED(FastGeluStaticSelection, half);
REGISTER_OP_TYPED(FastGeluStaticSelection, float);
REGISTER_OP_TYPED(FastGeluStaticSelection, double);
}
} // namespace onnxruntime