[ROCm] add tunable SkipLayerNorm for ROCm EP (#12817)

**Description**: Describe your changes.
Related PR: https://github.com/microsoft/onnxruntime/pull/12803
https://github.com/microsoft/onnxruntime/pull/12816
https://github.com/microsoft/onnxruntime/pull/12821

1.add tunable skip layernorm for rocm ep
2. keep origin implementation when disable tuning.

**Motivation and Context**
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
This commit is contained in:
PeixuanZuo 2022-09-23 16:39:44 +08:00 committed by GitHub
parent eafd67b8fd
commit 2ef1f8b93e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 27 deletions

View file

@ -5,6 +5,7 @@
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
#include "contrib_ops/rocm/bert/transformer_common.h"
namespace onnxruntime {
namespace contrib {
@ -30,6 +31,8 @@ template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
const TransformerOptions* options = TransformerOptions::GetInstance();
tuning_ = options->IsTuningEnabled();
}
template <typename T>
@ -106,7 +109,8 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
static_cast<int>(element_count));
static_cast<int>(element_count),
tuning_);
}
} // namespace rocm

View file

@ -19,6 +19,7 @@ class SkipLayerNorm final : public RocmKernel {
private:
float epsilon_;
bool tuning_;
};
} // namespace rocm

View file

@ -33,6 +33,7 @@ limitations under the License.
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h"
namespace onnxruntime {
namespace contrib {
@ -41,9 +42,18 @@ namespace rocm {
template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count) {
// this must be true because n is the total size of the tensor
const T* beta, const T* bias, float epsilon, int ld, int element_count, bool tuning) {
// this must be true because element_count is the total size of the tensor
assert(element_count % ld == 0);
if (tuning) {
static SkipLayerNormTunableOp<T> op;
op.EnableTuning();
SkipLayerNormParams<T> op_params(stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
return op(&op_params);
}
bool hasBias = (bias == nullptr) ? false : true;
if (0 == (ld % 4)) {
const int grid_size = element_count / ld;
@ -105,13 +115,13 @@ Status LaunchSkipLayerNormKernel(
template Status LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
const int element_count);
const float* bias, float epsilon, int ld,
int element_count, bool tuning);
template Status LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
const int element_count);
const half* bias, float epsilon, int ld,
int element_count, bool tuning);
} // namespace rocm
} // namespace contrib

View file

@ -11,16 +11,16 @@ namespace rocm {
template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream,
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count // number of elements in input tensor
);
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
bool tuning);
} // namespace rocm
} // namespace contrib

View file

@ -20,8 +20,7 @@ template <typename T>
struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams {
SkipLayerNormParams(hipStream_t stream, T* output, const T* input,
const T* skip, const T* gamma, const T* beta,
const T* bias, float epsilon, const int ld,
const int element_count)
const T* bias, float epsilon, int ld, int element_count)
: OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
epsilon(epsilon), ld(ld), element_count(element_count) {}
@ -36,9 +35,9 @@ struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams {
const T* gamma;
const T* beta;
const T* bias;
const float epsilon;
const int ld;
const int element_count;
float epsilon;
int ld;
int element_count;
};
template <typename T, int ThreadsPerBlock, int VecSize>
@ -55,6 +54,31 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
return HIP_CALL(hipGetLastError());
}
#define ADD_OP(threads_per_block) \
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 1>); \
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 2>); \
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 4>); \
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 8>); \
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 16>);
template <typename T>
class SkipLayerNormTunableOp : public onnxruntime::rocm::tunable::TunableOp<SkipLayerNormParams<T>> {
public:
SkipLayerNormTunableOp() {
ADD_OP(64)
ADD_OP(128)
ADD_OP(192)
ADD_OP(256)
ADD_OP(320)
ADD_OP(384)
// NOTE: the 3-th kernel seems to be better in gerenal case, so set it as default one
this->SetDefaultId(3);
}
};
#undef ADD_OP
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -6,7 +6,7 @@
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include "contrib_ops/rocm/bert/skip_layer_norm_op.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
@ -38,6 +38,32 @@ class SkipLayerNormSmall : public IKernelExplorer {
ParamsT params_{};
};
template <typename T>
class SkipLayerNormTunable : public IKernelExplorer {
public:
SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
op_.EnableTuning();
}
void Run() override {
ORT_THROW_IF_ERROR(op_(&params_));
}
bool IsSupported() {
return true;
}
private:
using ParamsT = contrib::rocm::SkipLayerNormParams<T>;
ParamsT params_{};
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
@ -63,9 +89,22 @@ class SkipLayerNormSmall : public IKernelExplorer {
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
#define REGISTER_TUNABLE_OP(type) \
py::class_<SkipLayerNormTunable<type>>(m, "SkipLayerNorm_" #type "_Tunable") \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &SkipLayerNormTunable<type>::SetRepeats) \
.def("Profile", &SkipLayerNormTunable<type>::Profile) \
.def("Run", &SkipLayerNormTunable<type>::Run) \
.def("IsSupported", &SkipLayerNormTunable<type>::IsSupported);
void InitSkipLayerNorm(py::module m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
REGISTER_TUNABLE_OP(half);
REGISTER_TUNABLE_OP(float);
}
} // namespace onnxruntime

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import re
from itertools import product
import kernel_explorer as ke
@ -19,8 +20,8 @@ def get_bert_sizes():
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))),
"float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))),
"float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))),
}
return type_map[dtype]
@ -109,9 +110,8 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
def profile():
bert_sizes = get_bert_sizes()
for dtype in dtypes:
for bert_size in bert_sizes:
for bert_size in get_bert_sizes():
for func in dtype_to_funcs(dtype):
profile_skip_layer_norm_func(*bert_size, dtype, func)
print()