mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
[ROCm] add tunable SkipLayerNorm for ROCm EP (#12817)
**Description**: Describe your changes. Related PR: https://github.com/microsoft/onnxruntime/pull/12803 https://github.com/microsoft/onnxruntime/pull/12816 https://github.com/microsoft/onnxruntime/pull/12821 1.add tunable skip layernorm for rocm ep 2. keep origin implementation when disable tuning. **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here.
This commit is contained in:
parent
eafd67b8fd
commit
2ef1f8b93e
7 changed files with 105 additions and 27 deletions
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
|
||||
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -30,6 +31,8 @@ template <typename T>
|
|||
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
|
||||
ORT_ENFORCE(epsilon_ >= 0);
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
tuning_ = options->IsTuningEnabled();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -106,7 +109,8 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count));
|
||||
static_cast<int>(element_count),
|
||||
tuning_);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class SkipLayerNorm final : public RocmKernel {
|
|||
|
||||
private:
|
||||
float epsilon_;
|
||||
bool tuning_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#include <hip/hip_fp16.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -41,9 +42,18 @@ namespace rocm {
|
|||
template <typename T>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count) {
|
||||
// this must be true because n is the total size of the tensor
|
||||
const T* beta, const T* bias, float epsilon, int ld, int element_count, bool tuning) {
|
||||
// this must be true because element_count is the total size of the tensor
|
||||
assert(element_count % ld == 0);
|
||||
|
||||
if (tuning) {
|
||||
static SkipLayerNormTunableOp<T> op;
|
||||
op.EnableTuning();
|
||||
|
||||
SkipLayerNormParams<T> op_params(stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
|
||||
return op(&op_params);
|
||||
}
|
||||
|
||||
bool hasBias = (bias == nullptr) ? false : true;
|
||||
if (0 == (ld % 4)) {
|
||||
const int grid_size = element_count / ld;
|
||||
|
|
@ -105,13 +115,13 @@ Status LaunchSkipLayerNormKernel(
|
|||
|
||||
template Status LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, const int ld,
|
||||
const int element_count);
|
||||
const float* bias, float epsilon, int ld,
|
||||
int element_count, bool tuning);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, const int ld,
|
||||
const int element_count);
|
||||
const half* bias, float epsilon, int ld,
|
||||
int element_count, bool tuning);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ namespace rocm {
|
|||
template <typename T>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream,
|
||||
T* output, // output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
const T* beta, // Layer normalization beta tensor
|
||||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count // number of elements in input tensor
|
||||
);
|
||||
T* output, // output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
const T* beta, // Layer normalization beta tensor
|
||||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count, // number of elements in input tensor
|
||||
bool tuning);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -20,8 +20,7 @@ template <typename T>
|
|||
struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams {
|
||||
SkipLayerNormParams(hipStream_t stream, T* output, const T* input,
|
||||
const T* skip, const T* gamma, const T* beta,
|
||||
const T* bias, float epsilon, const int ld,
|
||||
const int element_count)
|
||||
const T* bias, float epsilon, int ld, int element_count)
|
||||
: OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
|
||||
epsilon(epsilon), ld(ld), element_count(element_count) {}
|
||||
|
||||
|
|
@ -36,9 +35,9 @@ struct SkipLayerNormParams : onnxruntime::rocm::tunable::OpParams {
|
|||
const T* gamma;
|
||||
const T* beta;
|
||||
const T* bias;
|
||||
const float epsilon;
|
||||
const int ld;
|
||||
const int element_count;
|
||||
float epsilon;
|
||||
int ld;
|
||||
int element_count;
|
||||
};
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
|
|
@ -55,6 +54,31 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
|
|||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
#define ADD_OP(threads_per_block) \
|
||||
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 1>); \
|
||||
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 2>); \
|
||||
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 4>); \
|
||||
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 8>); \
|
||||
this->ops_.emplace_back(SkipLayerNormSmallOp<T, threads_per_block, 16>);
|
||||
|
||||
template <typename T>
|
||||
class SkipLayerNormTunableOp : public onnxruntime::rocm::tunable::TunableOp<SkipLayerNormParams<T>> {
|
||||
public:
|
||||
SkipLayerNormTunableOp() {
|
||||
ADD_OP(64)
|
||||
ADD_OP(128)
|
||||
ADD_OP(192)
|
||||
ADD_OP(256)
|
||||
ADD_OP(320)
|
||||
ADD_OP(384)
|
||||
|
||||
// NOTE: the 3-th kernel seems to be better in gerenal case, so set it as default one
|
||||
this->SetDefaultId(3);
|
||||
}
|
||||
};
|
||||
|
||||
#undef ADD_OP
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_op.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h"
|
||||
#include "python/tools/kernel_explorer/device_array.h"
|
||||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
|
|
@ -38,6 +38,32 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
ParamsT params_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SkipLayerNormTunable : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
|
||||
op_.EnableTuning();
|
||||
}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR(op_(¶ms_));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::SkipLayerNormParams<T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
|
|
@ -63,9 +89,22 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
|
||||
|
||||
#define REGISTER_TUNABLE_OP(type) \
|
||||
py::class_<SkipLayerNormTunable<type>>(m, "SkipLayerNorm_" #type "_Tunable") \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &SkipLayerNormTunable<type>::SetRepeats) \
|
||||
.def("Profile", &SkipLayerNormTunable<type>::Profile) \
|
||||
.def("Run", &SkipLayerNormTunable<type>::Run) \
|
||||
.def("IsSupported", &SkipLayerNormTunable<type>::IsSupported);
|
||||
|
||||
void InitSkipLayerNorm(py::module m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
|
||||
|
||||
REGISTER_TUNABLE_OP(half);
|
||||
REGISTER_TUNABLE_OP(float);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import re
|
||||
from itertools import product
|
||||
|
||||
import kernel_explorer as ke
|
||||
|
|
@ -19,8 +20,8 @@ def get_bert_sizes():
|
|||
|
||||
def dtype_to_funcs(dtype):
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))),
|
||||
"float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))),
|
||||
"float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))),
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
|
|
@ -109,9 +110,8 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
|
||||
|
||||
def profile():
|
||||
bert_sizes = get_bert_sizes()
|
||||
for dtype in dtypes:
|
||||
for bert_size in bert_sizes:
|
||||
for bert_size in get_bert_sizes():
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_skip_layer_norm_func(*bert_size, dtype, func)
|
||||
print()
|
||||
|
|
|
|||
Loading…
Reference in a new issue