mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
[ROCm] add SimplifiedSkipLayerNorm implementation (#17213)
add SimplifiedSkipLayerNorm implementation
This commit is contained in:
parent
4e6cec4d09
commit
d5c565156d
12 changed files with 249 additions and 84 deletions
|
|
@ -109,6 +109,61 @@ __device__ inline void LayerNorm(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V, int TPB>
|
||||
__device__ inline void SimplifiedLayerNorm(
|
||||
const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
|
||||
using BlockReduce = hipcub::BlockReduce<U, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ U rsigma; // 1 / std.dev.
|
||||
|
||||
const U sum = BlockReduce(temp_storage).Sum(thread_data);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rsigma = Rsqrt(sum + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const U val = static_cast<U>(output[idx]);
|
||||
const U g = static_cast<U>(gamma[i]);
|
||||
output[idx] = static_cast<V>(g * val * rsigma);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V, int TPB, int ILP>
|
||||
__device__ inline void SimplifiedLayerNormVec(
|
||||
const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
using VecV = aligned_vector<V, ILP>;
|
||||
using BlockReduce = hipcub::BlockReduce<U, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ U rsigma; // 1 / std.dev.
|
||||
|
||||
const U sum = BlockReduce(temp_storage).Sum(thread_data);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rsigma = Rsqrt(sum + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) {
|
||||
int idx = offset + i;
|
||||
const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + i);
|
||||
VecV output_v = *reinterpret_cast<const VecV*>(output + idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < ILP; k++) {
|
||||
output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma;
|
||||
}
|
||||
*(reinterpret_cast<VecV*>(output + idx)) = output_v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V, int TPB, int ILP>
|
||||
__device__ inline void LayerNormVec(
|
||||
const hipcub::KeyValuePair<U, U>& thread_data, const int ld, const int offset, const V* beta,
|
||||
|
|
@ -182,6 +237,36 @@ __device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePa
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V, int TPB, int ILP>
|
||||
__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx,
|
||||
const V* gamma, const U epsilon, V* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
// Small settings: the block covers the leading dimension TPB >= ld. The input
|
||||
// value is available in a register
|
||||
using VecV = aligned_vector<V, ILP>;
|
||||
using BlockReduce = hipcub::BlockReduce<U, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ U rsigma; // 1 / std.dev.
|
||||
|
||||
const U sum = BlockReduce(temp_storage).Sum(thread_data);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rsigma = Rsqrt(sum + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + threadIdx.x * ILP);
|
||||
VecV output_v;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma;
|
||||
}
|
||||
*(reinterpret_cast<VecV*>(output + idx)) = output_v;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -20,26 +20,36 @@ namespace rocm {
|
|||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
SkipLayerNorm<T>);
|
||||
SkipLayerNorm<T, false>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
SkipSimplifiedLayerNormalization, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
SkipLayerNorm<T, true>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
template <typename T>
|
||||
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
|
||||
template <typename T, bool Simplified>
|
||||
SkipLayerNorm<T, Simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
|
||||
ORT_ENFORCE(epsilon_ >= 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
template <typename T, bool Simplified>
|
||||
Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* input = ctx->Input<Tensor>(0);
|
||||
const Tensor* skip = ctx->Input<Tensor>(1);
|
||||
const Tensor* gamma = ctx->Input<Tensor>(2);
|
||||
const Tensor* beta = ctx->Input<Tensor>(3);
|
||||
const Tensor* bias = ctx->Input<Tensor>(4);
|
||||
|
||||
const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3);
|
||||
const Tensor* bias = Simplified ? ctx->Input<Tensor>(3) : ctx->Input<Tensor>(4);
|
||||
|
||||
Tensor* output = ctx->Output(0, input->Shape());
|
||||
|
||||
|
|
@ -102,7 +112,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
int64_t element_count = input->Shape().Size();
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
return LaunchSkipLayerNormKernel<HipT, float, HipT>(
|
||||
return LaunchSkipLayerNormKernel<HipT, float, HipT, Simplified>(
|
||||
GetTuningContext(),
|
||||
ctx->GetComputeStream(),
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace rocm {
|
|||
|
||||
using namespace onnxruntime::rocm;
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool Simplified>
|
||||
class SkipLayerNorm final : public RocmKernel {
|
||||
public:
|
||||
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
template <typename T, typename U, typename V, bool Simplified>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input,
|
||||
const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) {
|
||||
|
|
@ -50,20 +50,32 @@ Status LaunchSkipLayerNormKernel(
|
|||
gamma, beta, bias, epsilon, ld, element_count);
|
||||
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
static SkipLayerNormTunableOp<T, U, V> op;
|
||||
static SkipLayerNormTunableOp<T, U, V, Simplified> op;
|
||||
return op(¶ms);
|
||||
}
|
||||
|
||||
return SkipLayerNormStaticSelection<T, U, V>(¶ms);
|
||||
return SkipLayerNormStaticSelection<T, U, V, Simplified>(¶ms);
|
||||
}
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<float, float, float>(
|
||||
template Status LaunchSkipLayerNormKernel<float, float, float, true>(
|
||||
RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, int ld,
|
||||
int element_count);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<half, float, half>(
|
||||
template Status LaunchSkipLayerNormKernel<half, float, half, true>(
|
||||
RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, int ld,
|
||||
int element_count);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<float, float, float, false>(
|
||||
RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, int ld,
|
||||
int element_count);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<half, float, half, false>(
|
||||
RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, int ld,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
template <typename T, typename U, typename V, bool Simplified>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
RocmTuningContext* tuning,
|
||||
Stream* stream,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ half maybe2half(float x) {
|
|||
return __float2half_rn(x);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V, unsigned TPB>
|
||||
template <typename T, typename U, typename V, unsigned TPB, bool Simplified>
|
||||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias,
|
||||
const U epsilon, V* output, T* skip_input_bias_add_output) {
|
||||
|
|
@ -47,11 +47,16 @@ __global__ void SkipLayerNormKernel(
|
|||
output[idx] = static_cast<V>(val);
|
||||
}
|
||||
|
||||
if constexpr (Simplified) {
|
||||
SimplifiedLayerNorm<U, V, TPB>(thread_data.value, ld, offset, gamma, epsilon, output);
|
||||
return;
|
||||
}
|
||||
|
||||
LayerNorm<U, V, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
// Vectorized kernel
|
||||
template <typename T, typename U, typename V, unsigned TPB, int ILP>
|
||||
template <typename T, typename U, typename V, unsigned TPB, int ILP, bool Simplified>
|
||||
__global__ void SkipLayerNormKernelVec(
|
||||
const int ld, const T* input, const T* skip, const V* beta, const V* gamma,
|
||||
const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output,
|
||||
|
|
@ -94,11 +99,16 @@ __global__ void SkipLayerNormKernelVec(
|
|||
}
|
||||
}
|
||||
|
||||
if constexpr (Simplified) {
|
||||
SimplifiedLayerNormVec<U, V, TPB, ILP>(thread_data.value, ld, offset, gamma, epsilon, output);
|
||||
return;
|
||||
}
|
||||
|
||||
LayerNormVec<U, V, TPB, ILP>(thread_data, ld, offset, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
// Vectorized kernel
|
||||
template <typename T, typename U, typename V, unsigned TPB, int ILP>
|
||||
template <typename T, typename U, typename V, unsigned TPB, int ILP, bool Simplified>
|
||||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const V* beta, const V* gamma,
|
||||
const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output,
|
||||
|
|
@ -138,6 +148,12 @@ __global__ void SkipLayerNormKernelSmall(
|
|||
|
||||
thread_data = hipcub::KeyValuePair<U, U>(rldval_sum, rldvalsq_sum);
|
||||
}
|
||||
|
||||
if constexpr (Simplified) {
|
||||
SimplifiedLayerNormSmall<T, U, V, TPB, ILP>(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output);
|
||||
return;
|
||||
}
|
||||
|
||||
LayerNormSmall<T, U, V, TPB, ILP>(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,51 +42,51 @@ struct SkipLayerNormParams : OpParams {
|
|||
int element_count;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize>
|
||||
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize, bool Simplified>
|
||||
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T, V>* params) {
|
||||
// Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels,
|
||||
// which could offer better performance in some combinations of ThreadsPerBlock and VecSize.
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!((params->ld <= 8192 && params->ld % VecSize == 0 &&
|
||||
params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)));
|
||||
SkipLayerNormKernelSmall<T, U, V, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->StreamHandle()>>>(
|
||||
SkipLayerNormKernelSmall<T, U, V, ThreadsPerBlock, VecSize, Simplified><<<dim3(CeilDiv(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->StreamHandle()>>>(
|
||||
params->ld, params->input, params->skip,
|
||||
params->beta, params->gamma, params->bias, static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output,
|
||||
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize>
|
||||
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize, bool Simplified>
|
||||
Status SkipLayerNormRegularOp(const SkipLayerNormParams<T, V>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!((params->ld > 0 && params->ld % VecSize == 0 &&
|
||||
(params->ld >= ThreadsPerBlock * VecSize ||
|
||||
(params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)))));
|
||||
SkipLayerNormKernelVec<T, U, V, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->StreamHandle()>>>(
|
||||
SkipLayerNormKernelVec<T, U, V, ThreadsPerBlock, VecSize, Simplified><<<dim3(CeilDiv(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->StreamHandle()>>>(
|
||||
params->ld, params->input, params->skip,
|
||||
params->beta, params->gamma, params->bias, static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output,
|
||||
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
template <typename T, typename U, typename V, bool Simplified>
|
||||
Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
|
||||
bool hasBias = (params->bias == nullptr) ? false : true;
|
||||
bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true;
|
||||
const int grid_size = params->element_count / params->ld;
|
||||
const int block_size = 256;
|
||||
|
||||
#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \
|
||||
if (params->ld <= ELEMENTS) { \
|
||||
SkipLayerNormKernelSmall<T, U, V, TPB, ILP><<<grid_size, TPB, 0, params->StreamHandle()>>>( \
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \
|
||||
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output, \
|
||||
hasBias, hasSkipInputBiasAdditionOutput); \
|
||||
break; \
|
||||
#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \
|
||||
if (params->ld <= ELEMENTS) { \
|
||||
SkipLayerNormKernelSmall<T, U, V, TPB, ILP, Simplified><<<grid_size, TPB, 0, params->StreamHandle()>>>( \
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \
|
||||
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output, \
|
||||
hasBias, hasSkipInputBiasAdditionOutput); \
|
||||
break; \
|
||||
}
|
||||
if (0 == (params->ld % 4)) {
|
||||
do {
|
||||
|
|
@ -97,7 +97,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
|
|||
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4)
|
||||
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4)
|
||||
|
||||
SkipLayerNormKernel<T, U, V, block_size><<<grid_size, block_size, 0, params->StreamHandle()>>>(
|
||||
SkipLayerNormKernel<T, U, V, block_size, Simplified><<<grid_size, block_size, 0, params->StreamHandle()>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output);
|
||||
} while (0);
|
||||
|
|
@ -108,7 +108,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
|
|||
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1)
|
||||
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1)
|
||||
|
||||
SkipLayerNormKernel<T, U, V, block_size><<<grid_size, block_size, 0, params->StreamHandle()>>>(
|
||||
SkipLayerNormKernel<T, U, V, block_size, Simplified><<<grid_size, block_size, 0, params->StreamHandle()>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output);
|
||||
} while (0);
|
||||
|
|
@ -116,12 +116,12 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
|
|||
return HIP_CALL(hipPeekAtLastError());
|
||||
} // namespace rocm
|
||||
|
||||
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 1>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 2>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 4>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 8>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 16>);
|
||||
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 1, Simplified>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 2, Simplified>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 4, Simplified>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 8, Simplified>); \
|
||||
this->RegisterOp(name<T, U, V, threads_per_block, 16, Simplified>);
|
||||
|
||||
#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \
|
||||
|
|
@ -140,11 +140,11 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
|
|||
ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \
|
||||
ADD_OP_FOR_ALL_VEC_SIZE(name, 1024)
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
template <typename T, typename U, typename V, bool Simplified>
|
||||
class SkipLayerNormTunableOp : public TunableOp<SkipLayerNormParams<T, V>> {
|
||||
public:
|
||||
SkipLayerNormTunableOp() {
|
||||
this->RegisterOp(SkipLayerNormStaticSelection<T, U, V>);
|
||||
this->RegisterOp(SkipLayerNormStaticSelection<T, U, V, Simplified>);
|
||||
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp)
|
||||
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp)
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
|
||||
|
|
@ -219,6 +221,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ namespace py = pybind11;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
template <typename T, int ThreadsPerBlock, int VecSize, bool Simplified>
|
||||
class SkipLayerNormSmall : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
|
|
@ -23,11 +23,11 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize>(¶ms_);
|
||||
Status status = contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
ParamsT params_{};
|
||||
};
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
template <typename T, int ThreadsPerBlock, int VecSize, bool Simplified>
|
||||
class SkipLayerNormRegular : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
|
|
@ -47,11 +47,11 @@ class SkipLayerNormRegular : public IKernelExplorer {
|
|||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize>(¶ms_);
|
||||
Status status = contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
|
|
@ -60,7 +60,7 @@ class SkipLayerNormRegular : public IKernelExplorer {
|
|||
ParamsT params_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool Simplified>
|
||||
class SkipLayerNormStaticSelection : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
|
||||
|
|
@ -71,11 +71,11 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
|
|||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T, float, T>(¶ms_)));
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T, float, T, Simplified>(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = contrib::rocm::SkipLayerNormStaticSelection<T, float, T>(¶ms_);
|
||||
Status status = contrib::rocm::SkipLayerNormStaticSelection<T, float, T, Simplified>(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
|
|||
ParamsT params_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool Simplified>
|
||||
class SkipLayerNormTunable : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
|
|
@ -107,18 +107,22 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
private:
|
||||
using ParamsT = contrib::rocm::SkipLayerNormParams<T, T>;
|
||||
ParamsT params_{};
|
||||
contrib::rocm::SkipLayerNormTunableOp<T, float, T> op_{};
|
||||
contrib::rocm::SkipLayerNormTunableOp<T, float, T, Simplified> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
#define REGISTER_OP_COMMON(name, type, ...) \
|
||||
py::class_<type<__VA_ARGS__>>(m, name) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
|
||||
.def("Profile", &type<__VA_ARGS__>::Profile) \
|
||||
.def("Run", &type<__VA_ARGS__>::Run) \
|
||||
.def("IsSupported", &type<__VA_ARGS__>::IsSupported);
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
REGISTER_OP_COMMON("Simplified" #name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, true) \
|
||||
REGISTER_OP_COMMON(#name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, false)
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
REGISTER_OP(name, type, threads_per_block, 1) \
|
||||
|
|
@ -144,19 +148,24 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024)
|
||||
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
#define REGISTER_COMMON(name, type, ...) \
|
||||
py::class_<type<__VA_ARGS__>>(m, name) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type>::SetRepeats) \
|
||||
.def("Profile", &name<type>::Profile) \
|
||||
.def("Run", &name<type>::Run) \
|
||||
.def("IsSupported", &name<type>::IsSupported);
|
||||
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
|
||||
.def("Profile", &type<__VA_ARGS__>::Profile) \
|
||||
.def("Run", &type<__VA_ARGS__>::Run) \
|
||||
.def("IsSupported", &type<__VA_ARGS__>::IsSupported);
|
||||
|
||||
#define REGISTER_OP_TYPED(name, type) \
|
||||
REGISTER_COMMON("Simplified" #name "_" #type, name, type, true) \
|
||||
REGISTER_COMMON(#name "_" #type, name, type, false)
|
||||
|
||||
KE_REGISTER(m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
|
||||
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, float);
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from itertools import product
|
|||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
from utils import dtype_to_bytes, standardization
|
||||
from utils import dtype_to_bytes, root_mean_square, standardization
|
||||
|
||||
|
||||
def get_bert_sizes_test():
|
||||
|
|
@ -28,10 +28,11 @@ def get_bert_sizes_profile():
|
|||
return product(batch_sizes, seq_lens, hidden_sizes)
|
||||
|
||||
|
||||
def dtype_to_funcs(dtype):
|
||||
def dtype_to_funcs(dtype, simplified=False):
|
||||
skip_layer_norm_prefix = "SimplifiedSkipLayerNorm" if simplified else "SkipLayerNorm"
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))),
|
||||
"float16": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_half.*", x), dir(ke))),
|
||||
"float32": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_float.*", x), dir(ke))),
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
|
|
@ -43,7 +44,16 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
|
|||
return output, val
|
||||
|
||||
|
||||
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False):
|
||||
def simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon):
|
||||
val = input_x + skip + bias
|
||||
rms = root_mean_square(val, 2, epsilon)
|
||||
output = (val / rms) * gamma
|
||||
return output, val
|
||||
|
||||
|
||||
def run_skip_layer_norm(
|
||||
batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, simplified=False, has_optional_output=False
|
||||
):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
|
|
@ -86,20 +96,25 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
|
|||
y_d.UpdateHostNumpyArray()
|
||||
optional_d.UpdateHostNumpyArray()
|
||||
|
||||
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
|
||||
if simplified:
|
||||
y_ref, y_optional = simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon)
|
||||
else:
|
||||
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
|
||||
np.testing.assert_almost_equal(y_ref, output_y, decimal=1)
|
||||
if has_optional_output:
|
||||
np.testing.assert_almost_equal(y_optional, output_optional, decimal=3)
|
||||
|
||||
|
||||
dtypes = ["float32", "float16"]
|
||||
simplified = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bert_sizes", get_bert_sizes_test())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_skip_layer_norm(bert_sizes, dtype):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
run_skip_layer_norm(*bert_sizes, dtype, func)
|
||||
@pytest.mark.parametrize("simplified", simplified)
|
||||
def test_skip_layer_norm(bert_sizes, dtype, simplified):
|
||||
for func in dtype_to_funcs(dtype, simplified):
|
||||
run_skip_layer_norm(*bert_sizes, dtype, func, simplified)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -160,9 +175,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func,
|
|||
ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size))
|
||||
|
||||
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False):
|
||||
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False, simplified=False):
|
||||
with ke.benchmark(sort):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
for func in dtype_to_funcs(dtype, simplified):
|
||||
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output)
|
||||
|
||||
|
||||
|
|
@ -184,11 +199,18 @@ if __name__ == "__main__":
|
|||
group.add_argument("dtype", choices=dtypes)
|
||||
group.add_argument("--sort", action="store_true")
|
||||
group.add_argument("--has_optional_output", "-o", action="store_true")
|
||||
group.add_argument("--simplified", "-s", action="store_true", default=False)
|
||||
|
||||
if len(sys.argv) == 1:
|
||||
profile()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
profile_with_args(
|
||||
args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output
|
||||
args.batch_size,
|
||||
args.seq_len,
|
||||
args.hidden_size,
|
||||
args.dtype,
|
||||
args.sort,
|
||||
args.has_optional_output,
|
||||
args.simplified,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -133,6 +133,11 @@ def relu(x, bias):
|
|||
return np.max(x, 0, keepdims=True)
|
||||
|
||||
|
||||
def root_mean_square(x, axis, epsilon):
|
||||
rms = np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon)
|
||||
return rms
|
||||
|
||||
|
||||
def standardization(x, axis, epsilon):
|
||||
mean = np.mean(x, axis=axis, keepdims=True)
|
||||
variance = np.var(x, axis=axis, keepdims=True)
|
||||
|
|
|
|||
|
|
@ -698,8 +698,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) {
|
|||
true);
|
||||
}
|
||||
|
||||
// SkipSimplifiedLayerNorm has not been enabled for ROCm and DML yet
|
||||
#if !defined(USE_ROCM) && !defined(USE_DML)
|
||||
// SkipSimplifiedLayerNorm has not been enabled for DML yet
|
||||
#if !defined(USE_DML)
|
||||
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
|
|
@ -735,7 +735,9 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
|
|||
true,
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM) && !defined(USE_DML)
|
||||
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
|
|
|
|||
Loading…
Reference in a new issue