[ROCm] add SimplifiedSkipLayerNorm implementation (#17213)

add SimplifiedSkipLayerNorm implementation
This commit is contained in:
PeixuanZuo 2023-08-22 12:06:58 +08:00 committed by GitHub
parent 4e6cec4d09
commit d5c565156d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 249 additions and 84 deletions

View file

@ -109,6 +109,61 @@ __device__ inline void LayerNorm(
}
}
// Simplified layer normalization of one row, applied in place on `output`.
//
// Unlike LayerNorm there is no mean subtraction and no beta: every element is
// scaled by gamma and by a single row-wide factor rsqrt(sum + epsilon).
//
//   thread_data  this thread's partial contribution to the row reduction.
//                Precondition kept from the original code: the caller has
//                already divided it by ld.
//   ld           number of elements in the row processed by this block.
//   offset       index of the row's first element inside `output`.
//   gamma        per-element scale, read at indices [0, ld).
//   epsilon      added to the block-wide sum before the reciprocal square root.
//   output       holds the un-normalized values on entry; overwritten with
//                gamma[i] * output[offset + i] * rsqrt(sum + epsilon).
//
// Performs a block-wide reduction followed by __syncthreads(), so every thread
// of the block has to call it.
template <typename U, typename V, int TPB>
__device__ inline void SimplifiedLayerNorm(
    const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) {
  using BlockReduce = hipcub::BlockReduce<U, TPB>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  __shared__ U inv_rms;  // rsqrt(block-wide sum + epsilon), published by thread 0

  const U block_sum = BlockReduce(reduce_storage).Sum(thread_data);

  // Only thread 0 writes the shared scale; the barrier makes it visible to the rest.
  if (threadIdx.x == 0) {
    inv_rms = Rsqrt(block_sum + epsilon);
  }
  __syncthreads();

  // Each thread handles elements threadIdx.x, threadIdx.x + TPB, ... of the row.
  for (int col = threadIdx.x; col < ld; col += TPB) {
    const int pos = offset + col;
    const U scale = static_cast<U>(gamma[col]);
    const U value = static_cast<U>(output[pos]);
    output[pos] = static_cast<V>(scale * value * inv_rms);
  }
}
// Vectorized variant of simplified layer normalization, applied in place on
// `output`. Each thread loads, scales and stores ILP consecutive elements at a
// time through aligned_vector<V, ILP>.
//
//   thread_data  this thread's partial contribution to the row reduction.
//                Precondition kept from the original code: already divided by ld.
//   ld           number of elements in the row processed by this block.
//   offset       index of the row's first element inside `output`.
//   gamma        per-element scale, read in ILP-wide vectors.
//   epsilon      added to the block-wide sum before the reciprocal square root.
//   output       un-normalized values on entry; overwritten with the result.
//
// NOTE(review): the vector loads/stores assume `gamma + i` and `output + offset + i`
// are suitably aligned for aligned_vector<V, ILP> and that ld is a multiple of
// ILP -- confirm against the launch-side checks.
//
// Performs a block-wide reduction followed by __syncthreads(), so every thread
// of the block has to call it.
template <typename U, typename V, int TPB, int ILP>
__device__ inline void SimplifiedLayerNormVec(
    const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) {
  using VecV = aligned_vector<V, ILP>;
  using BlockReduce = hipcub::BlockReduce<U, TPB>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  __shared__ U inv_rms;  // rsqrt(block-wide sum + epsilon), published by thread 0

  const U block_sum = BlockReduce(reduce_storage).Sum(thread_data);

  // Only thread 0 writes the shared scale; the barrier makes it visible to the rest.
  if (threadIdx.x == 0) {
    inv_rms = Rsqrt(block_sum + epsilon);
  }
  __syncthreads();

  // Threads whose first vector would start at or beyond the row end have nothing to do.
  // No barrier follows, so leaving early here is safe.
  if (!(ILP * threadIdx.x < ld)) {
    return;
  }

  for (int col = threadIdx.x * ILP; col < ld; col += TPB * ILP) {
    const int pos = offset + col;
    VecV* const row_ptr = reinterpret_cast<VecV*>(output + pos);
    const VecV scale_vec = *reinterpret_cast<const VecV*>(gamma + col);
    VecV row_vec = *reinterpret_cast<const VecV*>(output + pos);

#pragma unroll
    for (int lane = 0; lane < ILP; lane++) {
      row_vec.val[lane] = U(scale_vec.val[lane]) * U(row_vec.val[lane]) * inv_rms;
    }

    *row_ptr = row_vec;
  }
}
template <typename U, typename V, int TPB, int ILP>
__device__ inline void LayerNormVec(
const hipcub::KeyValuePair<U, U>& thread_data, const int ld, const int offset, const V* beta,
@ -182,6 +237,36 @@ __device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePa
}
}
// Simplified layer normalization for the "small" configuration, where each
// thread already holds its ILP input values (input_v) and writes a single
// ILP-wide vector of results.
//
//   input_v      this thread's ILP un-normalized values (indexed 0..ILP-1).
//   thread_data  this thread's partial contribution to the row reduction.
//                Precondition kept from the original code: already divided by ld.
//   ld           number of elements in the row processed by this block.
//   idx          index inside `output` where this thread stores its vector.
//   gamma        per-element scale; this thread reads the vector starting at
//                gamma + threadIdx.x * ILP.
//   epsilon      added to the block-wide sum before the reciprocal square root.
//   output       destination; receives gamma * input * rsqrt(sum + epsilon).
//
// Per the original comment, the block is expected to cover the whole leading
// dimension in this configuration.
// NOTE(review): the vector accesses assume `gamma + threadIdx.x * ILP` and
// `output + idx` are suitably aligned for aligned_vector<V, ILP> -- confirm
// against the caller.
//
// Performs a block-wide reduction followed by __syncthreads(), so every thread
// of the block has to call it.
template <typename T, typename U, typename V, int TPB, int ILP>
__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx,
                                                const V* gamma, const U epsilon, V* output) {
  using VecV = aligned_vector<V, ILP>;
  using BlockReduce = hipcub::BlockReduce<U, TPB>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  __shared__ U inv_rms;  // rsqrt(block-wide sum + epsilon), published by thread 0

  const U block_sum = BlockReduce(reduce_storage).Sum(thread_data);

  // Only thread 0 writes the shared scale; the barrier makes it visible to the rest.
  if (threadIdx.x == 0) {
    inv_rms = Rsqrt(block_sum + epsilon);
  }
  __syncthreads();

  // Threads positioned at or beyond the row end produce no output.
  // No barrier follows, so leaving early here is safe.
  if (!(ILP * threadIdx.x < ld)) {
    return;
  }

  const VecV scale_vec = *reinterpret_cast<const VecV*>(gamma + threadIdx.x * ILP);
  VecV result_vec;

#pragma unroll
  for (int lane = 0; lane < ILP; lane++) {
    result_vec.val[lane] = U(scale_vec.val[lane]) * U(input_v[lane]) * inv_rms;
  }

  *(reinterpret_cast<VecV*>(output + idx)) = result_vec;
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -20,26 +20,36 @@ namespace rocm {
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T>);
SkipLayerNorm<T, false>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SkipSimplifiedLayerNormalization, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T, true>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
template <typename T, bool Simplified>
SkipLayerNorm<T, Simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
template <typename T>
Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
template <typename T, bool Simplified>
Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
const Tensor* skip = ctx->Input<Tensor>(1);
const Tensor* gamma = ctx->Input<Tensor>(2);
const Tensor* beta = ctx->Input<Tensor>(3);
const Tensor* bias = ctx->Input<Tensor>(4);
const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3);
const Tensor* bias = Simplified ? ctx->Input<Tensor>(3) : ctx->Input<Tensor>(4);
Tensor* output = ctx->Output(0, input->Shape());
@ -102,7 +112,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
int64_t element_count = input->Shape().Size();
typedef typename ToHipType<T>::MappedType HipT;
return LaunchSkipLayerNormKernel<HipT, float, HipT>(
return LaunchSkipLayerNormKernel<HipT, float, HipT, Simplified>(
GetTuningContext(),
ctx->GetComputeStream(),
reinterpret_cast<HipT*>(output->MutableData<T>()),

View file

@ -11,7 +11,7 @@ namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
template <typename T, bool Simplified>
class SkipLayerNorm final : public RocmKernel {
public:
SkipLayerNorm(const OpKernelInfo& op_kernel_info);

View file

@ -39,7 +39,7 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T, typename U, typename V>
template <typename T, typename U, typename V, bool Simplified>
Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input,
const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) {
@ -50,20 +50,32 @@ Status LaunchSkipLayerNormKernel(
gamma, beta, bias, epsilon, ld, element_count);
if (tuning_ctx->IsTunableOpEnabled()) {
static SkipLayerNormTunableOp<T, U, V> op;
static SkipLayerNormTunableOp<T, U, V, Simplified> op;
return op(&params);
}
return SkipLayerNormStaticSelection<T, U, V>(&params);
return SkipLayerNormStaticSelection<T, U, V, Simplified>(&params);
}
template Status LaunchSkipLayerNormKernel<float, float, float>(
template Status LaunchSkipLayerNormKernel<float, float, float, true>(
RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, int ld,
int element_count);
template Status LaunchSkipLayerNormKernel<half, float, half>(
template Status LaunchSkipLayerNormKernel<half, float, half, true>(
RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, int ld,
int element_count);
template Status LaunchSkipLayerNormKernel<float, float, float, false>(
RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, int ld,
int element_count);
template Status LaunchSkipLayerNormKernel<half, float, half, false>(
RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, int ld,

View file

@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T, typename U, typename V>
template <typename T, typename U, typename V, bool Simplified>
Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning,
Stream* stream,

View file

@ -23,7 +23,7 @@ half maybe2half(float x) {
return __float2half_rn(x);
}
template <typename T, typename U, typename V, unsigned TPB>
template <typename T, typename U, typename V, unsigned TPB, bool Simplified>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias,
const U epsilon, V* output, T* skip_input_bias_add_output) {
@ -47,11 +47,16 @@ __global__ void SkipLayerNormKernel(
output[idx] = static_cast<V>(val);
}
if constexpr (Simplified) {
SimplifiedLayerNorm<U, V, TPB>(thread_data.value, ld, offset, gamma, epsilon, output);
return;
}
LayerNorm<U, V, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
}
// Vectorized kernel
template <typename T, typename U, typename V, unsigned TPB, int ILP>
template <typename T, typename U, typename V, unsigned TPB, int ILP, bool Simplified>
__global__ void SkipLayerNormKernelVec(
const int ld, const T* input, const T* skip, const V* beta, const V* gamma,
const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output,
@ -94,11 +99,16 @@ __global__ void SkipLayerNormKernelVec(
}
}
if constexpr (Simplified) {
SimplifiedLayerNormVec<U, V, TPB, ILP>(thread_data.value, ld, offset, gamma, epsilon, output);
return;
}
LayerNormVec<U, V, TPB, ILP>(thread_data, ld, offset, beta, gamma, epsilon, output);
}
// Vectorized kernel
template <typename T, typename U, typename V, unsigned TPB, int ILP>
template <typename T, typename U, typename V, unsigned TPB, int ILP, bool Simplified>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const V* beta, const V* gamma,
const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output,
@ -138,6 +148,12 @@ __global__ void SkipLayerNormKernelSmall(
thread_data = hipcub::KeyValuePair<U, U>(rldval_sum, rldvalsq_sum);
}
if constexpr (Simplified) {
SimplifiedLayerNormSmall<T, U, V, TPB, ILP>(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output);
return;
}
LayerNormSmall<T, U, V, TPB, ILP>(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output);
}

View file

@ -42,51 +42,51 @@ struct SkipLayerNormParams : OpParams {
int element_count;
};
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize>
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize, bool Simplified>
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T, V>* params) {
// Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels,
// which could offer better performance in some combinations of ThreadsPerBlock and VecSize.
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!((params->ld <= 8192 && params->ld % VecSize == 0 &&
params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)));
SkipLayerNormKernelSmall<T, U, V, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->StreamHandle()>>>(
SkipLayerNormKernelSmall<T, U, V, ThreadsPerBlock, VecSize, Simplified><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->StreamHandle()>>>(
params->ld, params->input, params->skip,
params->beta, params->gamma, params->bias, static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output,
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
return HIP_CALL(hipGetLastError());
}
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize>
template <typename T, typename U, typename V, int ThreadsPerBlock, int VecSize, bool Simplified>
Status SkipLayerNormRegularOp(const SkipLayerNormParams<T, V>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!((params->ld > 0 && params->ld % VecSize == 0 &&
(params->ld >= ThreadsPerBlock * VecSize ||
(params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)))));
SkipLayerNormKernelVec<T, U, V, ThreadsPerBlock, VecSize><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->StreamHandle()>>>(
SkipLayerNormKernelVec<T, U, V, ThreadsPerBlock, VecSize, Simplified><<<dim3(CeilDiv(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->StreamHandle()>>>(
params->ld, params->input, params->skip,
params->beta, params->gamma, params->bias, static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output,
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
return HIP_CALL(hipGetLastError());
}
template <typename T, typename U, typename V>
template <typename T, typename U, typename V, bool Simplified>
Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
bool hasBias = (params->bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true;
const int grid_size = params->element_count / params->ld;
const int block_size = 256;
#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \
if (params->ld <= ELEMENTS) { \
SkipLayerNormKernelSmall<T, U, V, TPB, ILP><<<grid_size, TPB, 0, params->StreamHandle()>>>( \
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output, \
hasBias, hasSkipInputBiasAdditionOutput); \
break; \
#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \
if (params->ld <= ELEMENTS) { \
SkipLayerNormKernelSmall<T, U, V, TPB, ILP, Simplified><<<grid_size, TPB, 0, params->StreamHandle()>>>( \
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output, \
hasBias, hasSkipInputBiasAdditionOutput); \
break; \
}
if (0 == (params->ld % 4)) {
do {
@ -97,7 +97,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4)
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4)
SkipLayerNormKernel<T, U, V, block_size><<<grid_size, block_size, 0, params->StreamHandle()>>>(
SkipLayerNormKernel<T, U, V, block_size, Simplified><<<grid_size, block_size, 0, params->StreamHandle()>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output);
} while (0);
@ -108,7 +108,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1)
LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1)
SkipLayerNormKernel<T, U, V, block_size><<<grid_size, block_size, 0, params->StreamHandle()>>>(
SkipLayerNormKernel<T, U, V, block_size, Simplified><<<grid_size, block_size, 0, params->StreamHandle()>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
static_cast<U>(params->epsilon), params->output, params->skip_input_bias_add_output);
} while (0);
@ -116,12 +116,12 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
return HIP_CALL(hipPeekAtLastError());
} // namespace rocm
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
this->RegisterOp(name<T, U, V, threads_per_block, 1>); \
this->RegisterOp(name<T, U, V, threads_per_block, 2>); \
this->RegisterOp(name<T, U, V, threads_per_block, 4>); \
this->RegisterOp(name<T, U, V, threads_per_block, 8>); \
this->RegisterOp(name<T, U, V, threads_per_block, 16>);
#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \
this->RegisterOp(name<T, U, V, threads_per_block, 1, Simplified>); \
this->RegisterOp(name<T, U, V, threads_per_block, 2, Simplified>); \
this->RegisterOp(name<T, U, V, threads_per_block, 4, Simplified>); \
this->RegisterOp(name<T, U, V, threads_per_block, 8, Simplified>); \
this->RegisterOp(name<T, U, V, threads_per_block, 16, Simplified>);
#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \
@ -140,11 +140,11 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T, V>* params) {
ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \
ADD_OP_FOR_ALL_VEC_SIZE(name, 1024)
template <typename T, typename U, typename V>
template <typename T, typename U, typename V, bool Simplified>
class SkipLayerNormTunableOp : public TunableOp<SkipLayerNormParams<T, V>> {
public:
SkipLayerNormTunableOp() {
this->RegisterOp(SkipLayerNormStaticSelection<T, U, V>);
this->RegisterOp(SkipLayerNormStaticSelection<T, U, V, Simplified>);
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp)
ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp)

View file

@ -84,6 +84,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
@ -219,6 +221,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,

View file

@ -12,7 +12,7 @@ namespace py = pybind11;
namespace onnxruntime {
template <typename T, int ThreadsPerBlock, int VecSize>
template <typename T, int ThreadsPerBlock, int VecSize, bool Simplified>
class SkipLayerNormSmall : public IKernelExplorer {
public:
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
@ -23,11 +23,11 @@ class SkipLayerNormSmall : public IKernelExplorer {
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize>(&params_)));
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize>(&params_);
Status status = contrib::rocm::SkipLayerNormSmallOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(&params_);
return status.IsOK();
}
@ -36,7 +36,7 @@ class SkipLayerNormSmall : public IKernelExplorer {
ParamsT params_{};
};
template <typename T, int ThreadsPerBlock, int VecSize>
template <typename T, int ThreadsPerBlock, int VecSize, bool Simplified>
class SkipLayerNormRegular : public IKernelExplorer {
public:
SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
@ -47,11 +47,11 @@ class SkipLayerNormRegular : public IKernelExplorer {
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize>(&params_)));
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize>(&params_);
Status status = contrib::rocm::SkipLayerNormRegularOp<T, float, T, ThreadsPerBlock, VecSize, Simplified>(&params_);
return status.IsOK();
}
@ -60,7 +60,7 @@ class SkipLayerNormRegular : public IKernelExplorer {
ParamsT params_{};
};
template <typename T>
template <typename T, bool Simplified>
class SkipLayerNormStaticSelection : public IKernelExplorer {
public:
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
@ -71,11 +71,11 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T, float, T>(&params_)));
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T, float, T, Simplified>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::SkipLayerNormStaticSelection<T, float, T>(&params_);
Status status = contrib::rocm::SkipLayerNormStaticSelection<T, float, T, Simplified>(&params_);
return status.IsOK();
}
@ -84,7 +84,7 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
ParamsT params_{};
};
template <typename T>
template <typename T, bool Simplified>
class SkipLayerNormTunable : public IKernelExplorer {
public:
SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
@ -107,18 +107,22 @@ class SkipLayerNormTunable : public IKernelExplorer {
private:
using ParamsT = contrib::rocm::SkipLayerNormParams<T, T>;
ParamsT params_{};
contrib::rocm::SkipLayerNormTunableOp<T, float, T> op_{};
contrib::rocm::SkipLayerNormTunableOp<T, float, T, Simplified> op_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
// Binds the instantiation type<__VA_ARGS__> as a pybind11 class on module `m`
// (which must be in scope at the expansion site) under the Python name `name`.
// The exposed constructor takes seven DeviceArray& followed by float, int, int;
// the exposed methods are SetRepeats, Profile, Run and IsSupported.
// The expansion ends with ';', so it forms a complete statement on its own.
#define REGISTER_OP_COMMON(name, type, ...) \
py::class_<type<__VA_ARGS__>>(m, name) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
.def("Profile", &type<__VA_ARGS__>::Profile) \
.def("Run", &type<__VA_ARGS__>::Run) \
.def("IsSupported", &type<__VA_ARGS__>::IsSupported);
// Registers two Python classes for name<type, threads_per_block, vec_size, B>:
//   B = true  -> "Simplified<name>_<type>_<threads_per_block>_<vec_size>"
//   B = false -> "<name>_<type>_<threads_per_block>_<vec_size>"
// The Python-side tests select kernels by these name prefixes.
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
REGISTER_OP_COMMON("Simplified" #name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, true) \
REGISTER_OP_COMMON(#name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, false)
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
REGISTER_OP(name, type, threads_per_block, 1) \
@ -144,19 +148,24 @@ class SkipLayerNormTunable : public IKernelExplorer {
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024)
#define REGISTER_OP_TYPED(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
#define REGISTER_COMMON(name, type, ...) \
py::class_<type<__VA_ARGS__>>(m, name) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run) \
.def("IsSupported", &name<type>::IsSupported);
.def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \
.def("Profile", &type<__VA_ARGS__>::Profile) \
.def("Run", &type<__VA_ARGS__>::Run) \
.def("IsSupported", &type<__VA_ARGS__>::IsSupported);
// Registers two Python classes for name<type, B>:
//   B = true  -> "Simplified<name>_<type>"
//   B = false -> "<name>_<type>"
#define REGISTER_OP_TYPED(name, type) \
REGISTER_COMMON("Simplified" #name "_" #type, name, type, true) \
REGISTER_COMMON(#name "_" #type, name, type, false)
KE_REGISTER(m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, float);

View file

@ -11,7 +11,7 @@ from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_bytes, standardization
from utils import dtype_to_bytes, root_mean_square, standardization
def get_bert_sizes_test():
@ -28,10 +28,11 @@ def get_bert_sizes_profile():
return product(batch_sizes, seq_lens, hidden_sizes)
def dtype_to_funcs(dtype):
def dtype_to_funcs(dtype, simplified=False):
skip_layer_norm_prefix = "SimplifiedSkipLayerNorm" if simplified else "SkipLayerNorm"
type_map = {
"float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))),
"float16": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_half.*", x), dir(ke))),
"float32": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_float.*", x), dir(ke))),
}
return type_map[dtype]
@ -43,7 +44,16 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
return output, val
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False):
def simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon):
    """Reference implementation of the simplified skip layer norm.

    Adds ``input_x``, ``skip`` and ``bias``, divides the sum by its root mean
    square taken along axis 2 (as computed by ``root_mean_square`` with
    ``epsilon``), and scales the result by ``gamma``.

    Returns:
        A tuple ``(output, summed)`` where ``output`` is the normalized, scaled
        tensor and ``summed`` is ``input_x + skip + bias``.
    """
    summed = input_x + skip + bias
    normalized = summed / root_mean_square(summed, 2, epsilon)
    return normalized * gamma, summed
def run_skip_layer_norm(
batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, simplified=False, has_optional_output=False
):
np.random.seed(0)
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
@ -86,20 +96,25 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
y_d.UpdateHostNumpyArray()
optional_d.UpdateHostNumpyArray()
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
if simplified:
y_ref, y_optional = simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon)
else:
y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
np.testing.assert_almost_equal(y_ref, output_y, decimal=1)
if has_optional_output:
np.testing.assert_almost_equal(y_optional, output_optional, decimal=3)
dtypes = ["float32", "float16"]
simplified = [True, False]
@pytest.mark.parametrize("bert_sizes", get_bert_sizes_test())
@pytest.mark.parametrize("dtype", dtypes)
def test_skip_layer_norm(bert_sizes, dtype):
for func in dtype_to_funcs(dtype):
run_skip_layer_norm(*bert_sizes, dtype, func)
@pytest.mark.parametrize("simplified", simplified)
def test_skip_layer_norm(bert_sizes, dtype, simplified):
for func in dtype_to_funcs(dtype, simplified):
run_skip_layer_norm(*bert_sizes, dtype, func, simplified)
@dataclass
@ -160,9 +175,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func,
ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size))
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False):
def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False, simplified=False):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
for func in dtype_to_funcs(dtype, simplified):
profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output)
@ -184,11 +199,18 @@ if __name__ == "__main__":
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")
group.add_argument("--has_optional_output", "-o", action="store_true")
group.add_argument("--simplified", "-s", action="store_true", default=False)
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(
args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output
args.batch_size,
args.seq_len,
args.hidden_size,
args.dtype,
args.sort,
args.has_optional_output,
args.simplified,
)

View file

@ -133,6 +133,11 @@ def relu(x, bias):
return np.max(x, 0, keepdims=True)
def root_mean_square(x, axis, epsilon):
    """Return ``sqrt(mean(x**2) + epsilon)`` reduced along ``axis``.

    The reduced axis is kept with size 1 (``keepdims=True``) so the result
    broadcasts against ``x``.
    """
    squared = np.square(x)
    mean_of_squares = np.mean(squared, axis=axis, keepdims=True)
    return np.sqrt(mean_of_squares + epsilon)
def standardization(x, axis, epsilon):
mean = np.mean(x, axis=axis, keepdims=True)
variance = np.var(x, axis=axis, keepdims=True)

View file

@ -698,8 +698,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) {
true);
}
// SkipSimplifiedLayerNorm has not been enabled for ROCm and DML yet
#if !defined(USE_ROCM) && !defined(USE_DML)
// SkipSimplifiedLayerNorm has not been enabled for DML yet
#if !defined(USE_DML)
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
@ -735,7 +735,9 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
true,
true);
}
#endif
#if !defined(USE_ROCM) && !defined(USE_DML)
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
int batch_size = 2;
int sequence_length = 2;