From d5c565156db023af400ff4e18780bdb2862eb343 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Tue, 22 Aug 2023 12:06:58 +0800 Subject: [PATCH] [ROCm] add SimplifiedSkipLayerNorm implementation (#17213) add SimplifiedSkipLayerNorm implementation --- .../contrib_ops/rocm/bert/layer_norm.cuh | 85 +++++++++++++++++++ .../contrib_ops/rocm/bert/skip_layer_norm.cc | 26 ++++-- .../contrib_ops/rocm/bert/skip_layer_norm.h | 2 +- .../rocm/bert/skip_layer_norm_impl.cu | 22 +++-- .../rocm/bert/skip_layer_norm_impl.h | 2 +- .../rocm/bert/skip_layer_norm_impl_kernel.h | 22 ++++- .../rocm/bert/skip_layer_norm_tunable_op.h | 52 ++++++------ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 4 + .../kernels/rocm/skip_layer_norm.cu | 61 +++++++------ .../kernels/skip_layer_norm_test.py | 46 +++++++--- .../tools/kernel_explorer/kernels/utils.py | 5 ++ .../test/contrib_ops/skiplayernorm_op_test.cc | 6 +- 12 files changed, 249 insertions(+), 84 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh index 9b7dbd5291..3f9183ef10 100644 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -109,6 +109,61 @@ __device__ inline void LayerNorm( } } +template +__device__ inline void SimplifiedLayerNorm( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = static_cast(output[idx]); + const U g = static_cast(gamma[i]); + output[idx] = static_cast(g * val * rsigma); + } +} + +template +__device__ inline void SimplifiedLayerNormVec( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + const VecV gamma_v = *reinterpret_cast(gamma + i); + VecV output_v = *reinterpret_cast(output + idx); + + #pragma unroll + for (int k = 0; k < ILP; k++) { + output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } + } +} + template __device__ inline void LayerNormVec( const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, @@ -182,6 +237,36 @@ __device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePa } } +template +__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); + VecV output_v; + + #pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } +} + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc index b2d35560d2..9e649fb591 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -20,26 +20,36 @@ namespace rocm { kRocmExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) using namespace ONNX_NAMESPACE; -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { +template +Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input = ctx->Input(0); const Tensor* skip = ctx->Input(1); const Tensor* gamma = ctx->Input(2); - const Tensor* beta = ctx->Input(3); - const Tensor* bias = ctx->Input(4); + + const Tensor* beta = Simplified ? nullptr : ctx->Input(3); + const Tensor* bias = Simplified ? ctx->Input(3) : ctx->Input(4); Tensor* output = ctx->Output(0, input->Shape()); @@ -102,7 +112,7 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { int64_t element_count = input->Shape().Size(); typedef typename ToHipType::MappedType HipT; - return LaunchSkipLayerNormKernel( + return LaunchSkipLayerNormKernel( GetTuningContext(), ctx->GetComputeStream(), reinterpret_cast(output->MutableData()), diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h index 07d7037227..02228bc59c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h @@ -11,7 +11,7 @@ namespace rocm { using namespace onnxruntime::rocm; -template +template class SkipLayerNorm final : public RocmKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu index a02a4d00ab..8387c49a33 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -39,7 +39,7 @@ namespace onnxruntime { namespace contrib { namespace rocm { -template +template Status LaunchSkipLayerNormKernel( RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { @@ -50,20 +50,32 @@ Status LaunchSkipLayerNormKernel( gamma, beta, bias, epsilon, ld, element_count); if (tuning_ctx->IsTunableOpEnabled()) { - static SkipLayerNormTunableOp op; + static SkipLayerNormTunableOp op; return op(¶ms); } - return SkipLayerNormStaticSelection(¶ms); + return SkipLayerNormStaticSelection(¶ms); } -template Status LaunchSkipLayerNormKernel( +template Status LaunchSkipLayerNormKernel( RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, const float* skip, const float* gamma, const float* beta, const float* bias, float epsilon, int ld, int element_count); -template Status LaunchSkipLayerNormKernel( +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, const half* skip, const half* gamma, const half* beta, const half* bias, float epsilon, int ld, diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h index ae8ddcacca..5e2a92447d 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace rocm { -template +template Status LaunchSkipLayerNormKernel( RocmTuningContext* tuning, Stream* stream, diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h index 56ed185ffe..fcfbc8969e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -23,7 +23,7 @@ half maybe2half(float x) { return __float2half_rn(x); } -template +template __global__ void SkipLayerNormKernel( const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output) { @@ -47,11 +47,16 @@ __global__ void SkipLayerNormKernel( output[idx] = static_cast(val); } + if constexpr (Simplified) { + SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); } // Vectorized kernel -template +template __global__ void SkipLayerNormKernelVec( const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, @@ -94,11 +99,16 @@ __global__ void SkipLayerNormKernelVec( } } + if constexpr (Simplified) { + SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); } // Vectorized kernel -template +template __global__ void SkipLayerNormKernelSmall( const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, @@ -138,6 +148,12 @@ __global__ void SkipLayerNormKernelSmall( thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); } + + if constexpr (Simplified) { + SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); + return; + } + LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); } diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h index e85a527ea4..0391704ce1 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -42,51 +42,51 @@ struct SkipLayerNormParams : OpParams { int element_count; }; -template +template Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( !((params->ld <= 8192 && params->ld % VecSize == 0 && params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( + SkipLayerNormKernelSmall<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } -template +template Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( !((params->ld > 0 && params->ld % VecSize == 0 && (params->ld >= ThreadsPerBlock * VecSize || (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); - SkipLayerNormKernelVec<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( + SkipLayerNormKernelVec<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } -template +template Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { bool hasBias = (params->bias == nullptr) ? false : true; bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true; const int grid_size = params->element_count / params->ld; const int block_size = 256; -#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ - if (params->ld <= ELEMENTS) { \ - SkipLayerNormKernelSmall<<StreamHandle()>>>( \ - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ - hasBias, hasSkipInputBiasAdditionOutput); \ - break; \ +#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ + if (params->ld <= ELEMENTS) { \ + SkipLayerNormKernelSmall<<StreamHandle()>>>( \ + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ + hasBias, hasSkipInputBiasAdditionOutput); \ + break; \ } if (0 == (params->ld % 4)) { do { @@ -97,7 +97,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) - SkipLayerNormKernel<<StreamHandle()>>>( + SkipLayerNormKernel<<StreamHandle()>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); } while (0); @@ -108,7 +108,7 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) - SkipLayerNormKernel<<StreamHandle()>>>( + SkipLayerNormKernel<<StreamHandle()>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); } while (0); @@ -116,12 +116,12 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { return HIP_CALL(hipPeekAtLastError()); } // namespace rocm -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); +#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); #define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ @@ -140,11 +140,11 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) -template +template class SkipLayerNormTunableOp : public TunableOp> { public: SkipLayerNormTunableOp() { - this->RegisterOp(SkipLayerNormStaticSelection); + this->RegisterOp(SkipLayerNormStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 953601e999..7bc0f99081 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -84,6 +84,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); @@ -219,6 +221,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu index 55b79141f1..ec353f7e91 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu @@ -12,7 +12,7 @@ namespace py = pybind11; namespace onnxruntime { -template +template class SkipLayerNormSmall : public IKernelExplorer { public: SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, @@ -23,11 +23,11 @@ class SkipLayerNormSmall : public IKernelExplorer { static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp(¶ms_))); + ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp(¶ms_))); } bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormSmallOp(¶ms_); + Status status = contrib::rocm::SkipLayerNormSmallOp(¶ms_); return status.IsOK(); } @@ -36,7 +36,7 @@ class SkipLayerNormSmall : public IKernelExplorer { ParamsT params_{}; }; -template +template class SkipLayerNormRegular : public IKernelExplorer { public: SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, @@ -47,11 +47,11 @@ class SkipLayerNormRegular : public IKernelExplorer { static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp(¶ms_))); + ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp(¶ms_))); } bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormRegularOp(¶ms_); + Status status = contrib::rocm::SkipLayerNormRegularOp(¶ms_); return status.IsOK(); } @@ -60,7 +60,7 @@ class SkipLayerNormRegular : public IKernelExplorer { ParamsT params_{}; }; -template +template class SkipLayerNormStaticSelection : public IKernelExplorer { public: SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, @@ -71,11 +71,11 @@ class SkipLayerNormStaticSelection : public IKernelExplorer { static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection(¶ms_))); + ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection(¶ms_))); } bool IsSupported() { - Status status = contrib::rocm::SkipLayerNormStaticSelection(¶ms_); + Status status = contrib::rocm::SkipLayerNormStaticSelection(¶ms_); return status.IsOK(); } @@ -84,7 +84,7 @@ class SkipLayerNormStaticSelection : public IKernelExplorer { ParamsT params_{}; }; -template +template class SkipLayerNormTunable : public IKernelExplorer { public: SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, @@ -107,18 +107,22 @@ class SkipLayerNormTunable : public IKernelExplorer { private: using ParamsT = contrib::rocm::SkipLayerNormParams; ParamsT params_{}; - contrib::rocm::SkipLayerNormTunableOp op_{}; + contrib::rocm::SkipLayerNormTunableOp op_{}; }; -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("IsSupported", &name::IsSupported); +#define REGISTER_OP_COMMON(name, type, ...) \ + py::class_>(m, name) \ + .def(py::init()) \ + .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ + .def("Profile", &type<__VA_ARGS__>::Profile) \ + .def("Run", &type<__VA_ARGS__>::Run) \ + .def("IsSupported", &type<__VA_ARGS__>::IsSupported); + +#define REGISTER_OP(name, type, threads_per_block, vec_size) \ + REGISTER_OP_COMMON("Simplified" #name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, true) \ + REGISTER_OP_COMMON(#name "_" #type "_" #threads_per_block "_" #vec_size, name, type, threads_per_block, vec_size, false) #define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ REGISTER_OP(name, type, threads_per_block, 1) \ @@ -144,19 +148,24 @@ class SkipLayerNormTunable : public IKernelExplorer { REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 896) \ REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 1024) -#define REGISTER_OP_TYPED(name, type) \ - py::class_>(m, #name "_" #type) \ +#define REGISTER_COMMON(name, type, ...) \ + py::class_>(m, name) \ .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ - .def("IsSupported", &name::IsSupported); + .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ + .def("Profile", &type<__VA_ARGS__>::Profile) \ + .def("Run", &type<__VA_ARGS__>::Run) \ + .def("IsSupported", &type<__VA_ARGS__>::IsSupported); + +#define REGISTER_OP_TYPED(name, type) \ + REGISTER_COMMON("Simplified" #name "_" #type, name, type, true) \ + REGISTER_COMMON(#name "_" #type, name, type, false) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float); + REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, half); REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegular, float); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py index ddd5a53506..a31e8b851f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -11,7 +11,7 @@ from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import dtype_to_bytes, standardization +from utils import dtype_to_bytes, root_mean_square, standardization def get_bert_sizes_test(): @@ -28,10 +28,11 @@ def get_bert_sizes_profile(): return product(batch_sizes, seq_lens, hidden_sizes) -def dtype_to_funcs(dtype): +def dtype_to_funcs(dtype, simplified=False): + skip_layer_norm_prefix = "SimplifiedSkipLayerNorm" if simplified else "SkipLayerNorm" type_map = { - "float16": list(filter(lambda x: re.search("SkipLayerNorm.*_half", x), dir(ke))), - "float32": list(filter(lambda x: re.search("SkipLayerNorm.*_float", x), dir(ke))), + "float16": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_half.*", x), dir(ke))), + "float32": list(filter(lambda x: re.match(f"{skip_layer_norm_prefix}.*_float.*", x), dir(ke))), } return type_map[dtype] @@ -43,7 +44,16 @@ def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon): return output, val -def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, has_optional_output=False): +def simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon): + val = input_x + skip + bias + rms = root_mean_square(val, 2, epsilon) + output = (val / rms) * gamma + return output, val + + +def run_skip_layer_norm( + batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, simplified=False, has_optional_output=False +): np.random.seed(0) input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) @@ -86,20 +96,25 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: y_d.UpdateHostNumpyArray() optional_d.UpdateHostNumpyArray() - y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon) + if simplified: + y_ref, y_optional = simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon) + else: + y_ref, y_optional = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon) np.testing.assert_almost_equal(y_ref, output_y, decimal=1) if has_optional_output: np.testing.assert_almost_equal(y_optional, output_optional, decimal=3) dtypes = ["float32", "float16"] +simplified = [True, False] @pytest.mark.parametrize("bert_sizes", get_bert_sizes_test()) @pytest.mark.parametrize("dtype", dtypes) -def test_skip_layer_norm(bert_sizes, dtype): - for func in dtype_to_funcs(dtype): - run_skip_layer_norm(*bert_sizes, dtype, func) +@pytest.mark.parametrize("simplified", simplified) +def test_skip_layer_norm(bert_sizes, dtype, simplified): + for func in dtype_to_funcs(dtype, simplified): + run_skip_layer_norm(*bert_sizes, dtype, func, simplified) @dataclass @@ -160,9 +175,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) -def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False): +def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False, simplified=False): with ke.benchmark(sort): - for func in dtype_to_funcs(dtype): + for func in dtype_to_funcs(dtype, simplified): profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output) @@ -184,11 +199,18 @@ if __name__ == "__main__": group.add_argument("dtype", choices=dtypes) group.add_argument("--sort", action="store_true") group.add_argument("--has_optional_output", "-o", action="store_true") + group.add_argument("--simplified", "-s", action="store_true", default=False) if len(sys.argv) == 1: profile() else: args = parser.parse_args() profile_with_args( - args.batch_size, args.seq_len, args.hidden_size, args.dtype, args.sort, args.has_optional_output + args.batch_size, + args.seq_len, + args.hidden_size, + args.dtype, + args.sort, + args.has_optional_output, + args.simplified, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py index f595ab1c68..4901174373 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py @@ -133,6 +133,11 @@ def relu(x, bias): return np.max(x, 0, keepdims=True) +def root_mean_square(x, axis, epsilon): + rms = np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon) + return rms + + def standardization(x, axis, epsilon): mean = np.mean(x, axis=axis, keepdims=True) variance = np.var(x, axis=axis, keepdims=True) diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index a41a1dd4ec..2395532198 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -698,8 +698,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) { true); } -// SkipSimplifiedLayerNorm has not been enabled for ROCm and DML yet -#if !defined(USE_ROCM) && !defined(USE_DML) +// SkipSimplifiedLayerNorm has not been enabled for DML yet +#if !defined(USE_DML) TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { int batch_size = 1; int sequence_length = 2; @@ -735,7 +735,9 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { true, true); } +#endif +#if !defined(USE_ROCM) && !defined(USE_DML) TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { int batch_size = 2; int sequence_length = 2;